Allow Advertiser, DiscoveryManager runtime toggle

Allow toggling MdnsAdvertiser and MdnsDiscoveryManager at runtime, by
always creating them in NsdService constructor, but only using them when
the flag is on when starting discovery, resolve or registration.

When stopping, based on the type of the stored request, stop the
corresponding backend.

Bug: 265891278
Test: atest NsdServiceTest
Change-Id: I7cb2f9fe9e1ed3dc77616689a8e3ffa00f5bc269
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 49c6ef0..4ad39e1 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -118,13 +118,15 @@
     private final NsdStateMachine mNsdStateMachine;
     private final MDnsManager mMDnsManager;
     private final MDnsEventCallback mMDnsEventCallback;
-    @Nullable
+    @NonNull
+    private final Dependencies mDeps;
+    @NonNull
     private final MdnsMultinetworkSocketClient mMdnsSocketClient;
-    @Nullable
+    @NonNull
     private final MdnsDiscoveryManager mMdnsDiscoveryManager;
-    @Nullable
+    @NonNull
     private final MdnsSocketProvider mMdnsSocketProvider;
-    @Nullable
+    @NonNull
     private final MdnsAdvertiser mAdvertiser;
     // WARNING : Accessing these values in any thread is not safe, it must only be changed in the
     // state machine thread. If change this outside state machine, it will need to introduce
@@ -311,21 +313,14 @@
             mIsMonitoringSocketsStarted = true;
         }
 
-        private void maybeStopMonitoringSockets() {
-            if (!mIsMonitoringSocketsStarted) {
-                if (DBG) Log.d(TAG, "Socket monitoring has not been started.");
-                return;
-            }
+        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
+            if (!mIsMonitoringSocketsStarted) return;
+            if (isAnyRequestActive()) return;
+
             mMdnsSocketProvider.stopMonitoringSockets();
             mIsMonitoringSocketsStarted = false;
         }
 
-        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
-            if (!isAnyRequestActive()) {
-                maybeStopMonitoringSockets();
-            }
-        }
-
         NsdStateMachine(String name, Handler handler) {
             super(name, handler);
             addState(mDefaultState);
@@ -362,9 +357,7 @@
                                 mLegacyClientCount -= 1;
                             }
                         }
-                        if (mMdnsDiscoveryManager != null || mAdvertiser != null) {
-                            maybeStopMonitoringSocketsIfNoActiveRequest();
-                        }
+                        maybeStopMonitoringSocketsIfNoActiveRequest();
                         maybeScheduleStop();
                         break;
                     case NsdManager.DISCOVER_SERVICES:
@@ -579,7 +572,7 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
                             final String serviceType = constructServiceType(info.getServiceType());
                             if (serviceType == null) {
                                 clientInfo.onDiscoverServicesFailed(clientId,
@@ -634,6 +627,9 @@
                             break;
                         }
                         id = request.mGlobalId;
+                        // Note isMdnsDiscoveryManagerEnabled may have changed to false at this
+                        // point, so this needs to check the type of the original request to
+                        // unregister instead of looking at the flag value.
                         if (request instanceof DiscoveryManagerRequest) {
                             final MdnsListener listener =
                                     ((DiscoveryManagerRequest) request).mListener;
@@ -671,7 +667,7 @@
                         }
 
                         id = getUniqueId();
-                        if (mAdvertiser != null) {
+                        if (mDeps.isMdnsAdvertiserEnabled(mContext)) {
                             final NsdServiceInfo serviceInfo = args.serviceInfo;
                             final String serviceType = serviceInfo.getServiceType();
                             final String registerServiceType = constructServiceType(serviceType);
@@ -722,7 +718,10 @@
                         id = request.mGlobalId;
                         removeRequestMap(clientId, id, clientInfo);
 
-                        if (mAdvertiser != null) {
+                        // Note isMdnsAdvertiserEnabled may have changed to false at this point,
+                        // so this needs to check the type of the original request to unregister
+                        // instead of looking at the flag value.
+                        if (request instanceof AdvertiserClientRequest) {
                             mAdvertiser.removeService(id);
                             clientInfo.onUnregisterServiceSucceeded(clientId);
                         } else {
@@ -749,7 +748,7 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
                             final String serviceType = constructServiceType(info.getServiceType());
                             if (serviceType == null) {
                                 clientInfo.onResolveServiceFailed(clientId,
@@ -1241,32 +1240,16 @@
         mNsdStateMachine.start();
         mMDnsManager = ctx.getSystemService(MDnsManager.class);
         mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine);
+        mDeps = deps;
 
-        final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx);
-        final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx);
-        if (discoveryManagerEnabled || advertiserEnabled) {
-            mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
-        } else {
-            mMdnsSocketProvider = null;
-        }
-
-        if (discoveryManagerEnabled) {
-            mMdnsSocketClient =
-                    new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
-            mMdnsDiscoveryManager =
-                    deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
-            handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
-        } else {
-            mMdnsSocketClient = null;
-            mMdnsDiscoveryManager = null;
-        }
-
-        if (advertiserEnabled) {
-            mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
-                    new AdvertiserCallback());
-        } else {
-            mAdvertiser = null;
-        }
+        mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
+        mMdnsSocketClient =
+                new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
+        mMdnsDiscoveryManager =
+                deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
+        handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
+        mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
+                new AdvertiserCallback());
     }
 
     /**
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 98a8ed2..a2c4b9b 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -45,6 +45,7 @@
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import android.compat.testing.PlatformCompatChangeRule;
@@ -170,6 +171,9 @@
         doReturn(true).when(mMockMDnsM).resolve(
                 anyInt(), anyString(), anyString(), anyString(), anyInt());
         doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
+        doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
+        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
+        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
 
         mService = makeService();
     }
@@ -824,40 +828,50 @@
                 client.unregisterServiceInfoCallback(serviceInfoCallback));
     }
 
-    private void makeServiceWithMdnsDiscoveryManagerEnabled() {
+    private void setMdnsDiscoveryManagerEnabled() {
         doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
-        doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
-
-        mService = makeService();
-        verify(mDeps).makeMdnsDiscoveryManager(any(), any());
-        verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
-    private void makeServiceWithMdnsAdvertiserEnabled() {
+    private void setMdnsAdvertiserEnabled() {
         doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class));
-        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
-
-        mService = makeService();
-        verify(mDeps).makeMdnsAdvertiser(any(), any(), any());
-        verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
     @Test
     public void testMdnsDiscoveryManagerFeature() {
         // Create NsdService w/o feature enabled.
-        connectClient(mService);
-        verify(mDeps, never()).makeMdnsDiscoveryManager(any(), any());
-        verify(mDeps, never()).makeMdnsSocketProvider(any(), any());
+        final NsdManager client = connectClient(mService);
+        final DiscoveryListener discListenerWithoutFeature = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithoutFeature);
+        waitForIdle();
 
-        // Create NsdService again w/ feature enabled.
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mMockMDnsM).discover(legacyIdCaptor.capture(), any(), anyInt());
+        verifyNoMoreInteractions(mDiscoveryManager);
+
+        setMdnsDiscoveryManagerEnabled();
+        final DiscoveryListener discListenerWithFeature = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithFeature);
+        waitForIdle();
+
+        final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
+                ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
+        verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
+                listenerCaptor.capture(), any());
+
+        client.stopServiceDiscovery(discListenerWithoutFeature);
+        waitForIdle();
+        verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
+
+        client.stopServiceDiscovery(discListenerWithFeature);
+        waitForIdle();
+        verify(mDiscoveryManager).unregisterListener(serviceTypeWithLocalDomain,
+                listenerCaptor.getValue());
     }
 
     @Test
     public void testDiscoveryWithMdnsDiscoveryManager() {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -922,7 +936,7 @@
 
     @Test
     public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -951,7 +965,7 @@
 
     @Test
     public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final ResolveListener resolveListener = mock(ResolveListener.class);
@@ -1005,8 +1019,43 @@
     }
 
     @Test
+    public void testMdnsAdvertiserFeatureFlagging() {
+        // Create NsdService w/o feature enabled.
+        final NsdManager client = connectClient(mService);
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        regInfo.setHost(parseNumericAddress("192.0.2.123"));
+        regInfo.setPort(12345);
+        final RegistrationListener regListenerWithoutFeature = mock(RegistrationListener.class);
+        client.registerService(regInfo, PROTOCOL, regListenerWithoutFeature);
+        waitForIdle();
+
+        final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mMockMDnsM).registerService(legacyIdCaptor.capture(), any(), any(), anyInt(),
+                any(), anyInt());
+        verifyNoMoreInteractions(mAdvertiser);
+
+        setMdnsAdvertiserEnabled();
+        final RegistrationListener regListenerWithFeature = mock(RegistrationListener.class);
+        client.registerService(regInfo, PROTOCOL, regListenerWithFeature);
+        waitForIdle();
+
+        final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mAdvertiser).addService(serviceIdCaptor.capture(),
+                argThat(info -> matches(info, regInfo)));
+
+        client.unregisterService(regListenerWithoutFeature);
+        waitForIdle();
+        verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
+        verify(mAdvertiser, never()).removeService(anyInt());
+
+        client.unregisterService(regListenerWithFeature);
+        waitForIdle();
+        verify(mAdvertiser).removeService(serviceIdCaptor.getValue());
+    }
+
+    @Test
     public void testAdvertiseWithMdnsAdvertiser() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1045,7 +1094,7 @@
 
     @Test
     public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1070,7 +1119,7 @@
 
     @Test
     public void testAdvertiseWithMdnsAdvertiser_LongServiceName() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);