Use MdnsDiscoveryManager for discovery

Register/unregister the listener to/from MdnsDiscoveryManager
when discovery is started/stopped.

Bug: 254166302
Test: atest FrameworksNetTests CtsNetTestCases
Change-Id: Ibd782029826ac5856c608165928cd942e46dd9a4
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index b5a10e2..c93fcf6 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -20,6 +20,7 @@
 import static android.net.nsd.NsdManager.MDNS_SERVICE_EVENT;
 import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
 
+import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
 import android.content.Intent;
@@ -45,6 +46,7 @@
 import android.os.Message;
 import android.os.RemoteException;
 import android.os.UserHandle;
+import android.text.TextUtils;
 import android.util.Log;
 import android.util.Pair;
 import android.util.SparseArray;
@@ -58,6 +60,9 @@
 import com.android.server.connectivity.mdns.ExecutorProvider;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
 import com.android.server.connectivity.mdns.MdnsMultinetworkSocketClient;
+import com.android.server.connectivity.mdns.MdnsSearchOptions;
+import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
+import com.android.server.connectivity.mdns.MdnsServiceInfo;
 import com.android.server.connectivity.mdns.MdnsSocketClientBase;
 import com.android.server.connectivity.mdns.MdnsSocketProvider;
 
@@ -68,6 +73,9 @@
 import java.net.SocketException;
 import java.net.UnknownHostException;
 import java.util.HashMap;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
 /**
  * Network Service Discovery Service handles remote service discovery operation requests by
@@ -79,6 +87,7 @@
     private static final String TAG = "NsdService";
     private static final String MDNS_TAG = "mDnsConnector";
     private static final String MDNS_DISCOVERY_MANAGER_VERSION = "mdns_discovery_manager_version";
+    private static final String LOCAL_DOMAIN_NAME = "local";
 
     private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final long CLEANUP_DELAY_MS = 10000;
@@ -94,10 +103,11 @@
     private final MdnsDiscoveryManager mMdnsDiscoveryManager;
     @Nullable
     private final MdnsSocketProvider mMdnsSocketProvider;
-    // WARNING : Accessing this value in any thread is not safe, it must only be changed in the
+    // WARNING : Accessing these values in any thread is not safe, they must only be changed in the
     // state machine thread. If change this outside state machine, it will need to introduce
     // synchronization.
     private boolean mIsDaemonStarted = false;
+    private boolean mIsMonitoringSocketsStarted = false;
 
     /**
      * Clients receiving asynchronous messages
@@ -114,6 +124,73 @@
     // The count of the connected legacy clients.
     private int mLegacyClientCount = 0;
 
+    /**
+     * Base browser listener: remembers the request it serves; every callback is a no-op here.
+     */
+    private static class MdnsListener implements MdnsServiceBrowserListener {
+        protected final int mClientId;
+        protected final int mTransactionId;
+        @NonNull protected final NsdServiceInfo mReqServiceInfo;
+        @NonNull protected final String mListenedServiceType;
+
+        MdnsListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
+                @NonNull String listenedServiceType) {
+            mListenedServiceType = listenedServiceType;
+            mReqServiceInfo = reqServiceInfo;
+            mTransactionId = transactionId;
+            mClientId = clientId;
+        }
+
+        @NonNull public String getListenedServiceType() {
+            return mListenedServiceType;
+        }
+
+        @Override
+        public void onServiceFound(@NonNull MdnsServiceInfo serviceInfo) {}
+
+        @Override
+        public void onServiceUpdated(@NonNull MdnsServiceInfo serviceInfo) {}
+
+        @Override
+        public void onServiceRemoved(@NonNull MdnsServiceInfo serviceInfo) {}
+
+        @Override
+        public void onServiceNameDiscovered(@NonNull MdnsServiceInfo serviceInfo) {}
+
+        @Override
+        public void onServiceNameRemoved(@NonNull MdnsServiceInfo serviceInfo) {}
+
+        @Override
+        public void onSearchStoppedWithError(int error) {}
+
+        @Override
+        public void onSearchFailedToStart() {}
+
+        @Override
+        public void onDiscoveryQuerySent(@NonNull List<String> subtypes, int transactionId) {}
+
+        @Override
+        public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode) {}
+    }
+
+    /** Listener created for discovery requests; its name callbacks are still unimplemented. */
+    private class DiscoveryListener extends MdnsListener {
+        DiscoveryListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
+                @NonNull String listenServiceType) {
+            super(clientId, transactionId, reqServiceInfo, listenServiceType);
+        }
+
+        @Override
+        public void onServiceNameDiscovered(@NonNull MdnsServiceInfo serviceInfo) {
+            // TODO: implement service name discovered callback.
+        }
+
+        @Override
+        public void onServiceNameRemoved(@NonNull MdnsServiceInfo serviceInfo) {
+            // TODO: implement service name removed callback.
+        }
+    }
+
     private class NsdStateMachine extends StateMachine {
 
         private final DefaultState mDefaultState = new DefaultState();
@@ -164,6 +241,31 @@
             this.removeMessages(NsdManager.DAEMON_CLEANUP);
         }
 
+        /** Starts socket monitoring on mMdnsSocketProvider unless it is already started. */
+        private void maybeStartMonitoringSockets() {
+            if (mIsMonitoringSocketsStarted) {
+                if (DBG) Log.d(TAG, "Socket monitoring is already started.");
+            } else {
+                mMdnsSocketProvider.startMonitoringSockets();
+                mIsMonitoringSocketsStarted = true;
+            }
+        }
+
+        private void maybeStopMonitoringSockets() {
+            if (mIsMonitoringSocketsStarted) {
+                mMdnsSocketProvider.stopMonitoringSockets();
+                mIsMonitoringSocketsStarted = false;
+            } else if (DBG) {
+                Log.d(TAG, "Socket monitoring has not been started.");
+            }
+        }
+
+        /** Stops socket monitoring unless isAnyRequestActive() reports an active request. */
+        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
+            if (isAnyRequestActive()) return;
+            maybeStopMonitoringSockets();
+        }
+
         NsdStateMachine(String name, Handler handler) {
             super(name, handler);
             addState(mDefaultState);
@@ -195,11 +297,17 @@
                         final NsdServiceConnector connector = (NsdServiceConnector) msg.obj;
                         cInfo = mClients.remove(connector);
                         if (cInfo != null) {
+                            if (mMdnsDiscoveryManager != null) {
+                                cInfo.unregisterAllListeners();
+                            }
                             cInfo.expungeAllRequests();
                             if (cInfo.isLegacy()) {
                                 mLegacyClientCount -= 1;
                             }
                         }
+                        if (mMdnsDiscoveryManager != null) {
+                            maybeStopMonitoringSocketsIfNoActiveRequest();
+                        }
                         maybeScheduleStop();
                         break;
                     case NsdManager.DISCOVER_SERVICES:
@@ -251,6 +359,9 @@
                             maybeStartDaemon();
                         }
                         break;
+                    case NsdManager.MDNS_MONITORING_SOCKETS_CLEANUP:
+                        maybeStopMonitoringSockets();
+                        break;
                     default:
                         Log.e(TAG, "Unhandled " + msg);
                         return NOT_HANDLED;
@@ -300,6 +411,47 @@
                 maybeScheduleStop();
             }
 
+            private void storeListenerMap(int clientId, int transactionId, MdnsListener listener,
+                    ClientInfo clientInfo) {
+                removeMessages(NsdManager.MDNS_MONITORING_SOCKETS_CLEANUP);
+                mIdToClientInfoMap.put(transactionId, clientInfo);
+                clientInfo.mListeners.put(clientId, listener);
+                clientInfo.mClientIds.put(clientId, transactionId);
+            }
+
+            private void removeListenerMap(int clientId, int transactionId, ClientInfo clientInfo) {
+                mIdToClientInfoMap.remove(transactionId);
+                clientInfo.mListeners.delete(clientId);
+                clientInfo.mClientIds.delete(clientId);
+                maybeStopMonitoringSocketsIfNoActiveRequest();
+            }
+
+            /**
+             * Validates a requested service type and builds the fully qualified type to query.
+             *
+             * <p>A valid input has 2 labels ("_type._tcp"), or 3 labels when a subtype is
+             * requested ("_subtype._type._tcp", see RFC6763 7.1). Each label before _tcp / _udp
+             * is 3 to 63 characters long: an underscore, then letters, digits, dashes or
+             * underscores, ending with a letter or digit.
+             *
+             * @param serviceType the service type from the discovery / resolution request
+             * @return "_type._tcp.local" or "_subtype._sub._type._tcp.local", or null if invalid.
+             */
+            @Nullable
+            private String constructServiceType(String serviceType) {
+                if (TextUtils.isEmpty(serviceType)) return null;
+
+                final Pattern serviceTypePattern = Pattern.compile(
+                        "^(_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]\\.)?"
+                                + "(_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]\\._(?:tcp|udp))$");
+                final Matcher matcher = serviceTypePattern.matcher(serviceType);
+                if (!matcher.matches()) return null;
+                // group(1), when present, already ends with the '.' separating it from "_sub".
+                return matcher.group(1) == null
+                        ? serviceType + ".local"
+                        : matcher.group(1) + "_sub." + matcher.group(2) + ".local";
+            }
+
             @Override
             public boolean processMessage(Message msg) {
                 final ClientInfo clientInfo;
@@ -325,19 +477,40 @@
                             break;
                         }
 
-                        maybeStartDaemon();
+                        final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (discoverServices(id, args.serviceInfo)) {
-                            if (DBG) {
-                                Log.d(TAG, "Discover " + msg.arg2 + " " + id
-                                        + args.serviceInfo.getServiceType());
+                        if (mMdnsDiscoveryManager != null) {
+                            final String serviceType = constructServiceType(info.getServiceType());
+                            if (serviceType == null) {
+                                clientInfo.onDiscoverServicesFailed(clientId,
+                                        NsdManager.FAILURE_INTERNAL_ERROR);
+                                break;
                             }
-                            storeRequestMap(clientId, id, clientInfo, msg.what);
-                            clientInfo.onDiscoverServicesStarted(clientId, args.serviceInfo);
+
+                            maybeStartMonitoringSockets();
+                            final MdnsListener listener =
+                                    new DiscoveryListener(clientId, id, info, serviceType);
+                            final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
+                                    .setNetwork(info.getNetwork())
+                                    .setIsPassiveMode(true)
+                                    .build();
+                            mMdnsDiscoveryManager.registerListener(serviceType, listener, options);
+                            storeListenerMap(clientId, id, listener, clientInfo);
+                            clientInfo.onDiscoverServicesStarted(clientId, info);
                         } else {
-                            stopServiceDiscovery(id);
-                            clientInfo.onDiscoverServicesFailed(clientId,
-                                    NsdManager.FAILURE_INTERNAL_ERROR);
+                            maybeStartDaemon();
+                            if (discoverServices(id, info)) {
+                                if (DBG) {
+                                    Log.d(TAG, "Discover " + msg.arg2 + " " + id
+                                            + info.getServiceType());
+                                }
+                                storeRequestMap(clientId, id, clientInfo, msg.what);
+                                clientInfo.onDiscoverServicesStarted(clientId, info);
+                            } else {
+                                stopServiceDiscovery(id);
+                                clientInfo.onDiscoverServicesFailed(clientId,
+                                        NsdManager.FAILURE_INTERNAL_ERROR);
+                            }
                         }
                         break;
                     case NsdManager.STOP_DISCOVERY:
@@ -359,12 +532,25 @@
                                     clientId, NsdManager.FAILURE_INTERNAL_ERROR);
                             break;
                         }
-                        removeRequestMap(clientId, id, clientInfo);
-                        if (stopServiceDiscovery(id)) {
+                        if (mMdnsDiscoveryManager != null) {
+                            final MdnsListener listener = clientInfo.mListeners.get(clientId);
+                            if (listener == null) {
+                                clientInfo.onStopDiscoveryFailed(
+                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                break;
+                            }
+                            mMdnsDiscoveryManager.unregisterListener(
+                                    listener.getListenedServiceType(), listener);
+                            removeListenerMap(clientId, id, clientInfo);
                             clientInfo.onStopDiscoverySucceeded(clientId);
                         } else {
-                            clientInfo.onStopDiscoveryFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                            removeRequestMap(clientId, id, clientInfo);
+                            if (stopServiceDiscovery(id)) {
+                                clientInfo.onStopDiscoverySucceeded(clientId);
+                            } else {
+                                clientInfo.onStopDiscoveryFailed(
+                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                            }
                         }
                         break;
                     case NsdManager.REGISTER_SERVICE:
@@ -982,6 +1168,9 @@
         /* A map from client id to the type of the request we had received */
         private final SparseIntArray mClientRequests = new SparseIntArray();
 
+        /* A map from client id to the MdnsListener */
+        private final SparseArray<MdnsListener> mListeners = new SparseArray<>();
+
         // The target SDK of this client < Build.VERSION_CODES.S
         private boolean mIsLegacy = false;
 
@@ -1043,6 +1232,15 @@
             mClientRequests.clear();
         }
 
+        /** Unregisters all stored listeners from mMdnsDiscoveryManager, then clears them. */
+        void unregisterAllListeners() {
+            for (int index = 0; index < mListeners.size(); index++) {
+                final MdnsListener entry = mListeners.valueAt(index);
+                mMdnsDiscoveryManager.unregisterListener(entry.getListenedServiceType(), entry);
+            }
+            mListeners.clear();
+        }
+
         // mClientIds is a sparse array of listener id -> mDnsClient id.  For a given mDnsClient id,
         // return the corresponding listener id.  mDnsClient id is also called a global id.
         private int getClientId(final int globalId) {