Merge "Register service on the specified network" into tm-dev
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 4086e4e..ea57bac 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -738,7 +738,13 @@
         String type = service.getServiceType();
         int port = service.getPort();
         byte[] textRecord = service.getTxtRecord();
-        return mMDnsManager.registerService(regId, name, type, port, textRecord, IFACE_IDX_ANY);
+        final Network network = service.getNetwork();
+        final int registerInterface = getNetworkInterfaceIndex(network);
+        if (network != null && registerInterface == IFACE_IDX_ANY) {
+            Log.e(TAG, "Interface to register service on not found");
+            return false;
+        }
+        return mMDnsManager.registerService(regId, name, type, port, textRecord, registerInterface);
     }
 
     private boolean unregisterService(int regId) {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index b139a9b..7b0451f 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -82,6 +82,7 @@
 private const val TAG = "NsdManagerTest"
 private const val SERVICE_TYPE = "_nmt._tcp"
 private const val TIMEOUT_MS = 2000L
+private const val NO_CALLBACK_TIMEOUT_MS = 200L
 private const val DBG = false
 
 private val nsdShim = NsdShimImpl.newInstance()
@@ -136,6 +137,17 @@
                     nextEvent.javaClass.simpleName)
             return nextEvent
         }
+
+        /**
+         * Assert that no event is received on this record within [timeoutMs].
+         *
+         * Polls [nextEvents] once and fails, naming the unexpected event, if one arrives.
+         * Not `inline`: there is no lambda or reified type parameter for inlining to help.
+         */
+        fun assertNoCallback(timeoutMs: Long = NO_CALLBACK_TIMEOUT_MS) {
+            val cb = nextEvents.poll(timeoutMs)
+            assertNull(cb, "Expected no callback but got $cb")
+        }
     }
 
     private class NsdRegistrationRecord : RegistrationListener,
@@ -556,6 +562,69 @@
         }
     }
 
+    /**
+     * Verify that a service registered with a specific network is only advertised on it.
+     *
+     * The service is registered on testNetwork1; discovery is then run on testNetwork1 (service
+     * expected), on testNetwork2 (only DiscoveryStarted expected) and on all networks (service
+     * expected, reported on testNetwork1).
+     */
+    @Test
+    fun testNsdManager_RegisterOnNetwork() {
+        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
+        assumeTrue(ConstantsShim.VERSION > SC_V2)
+
+        val si = NsdServiceInfo()
+        si.serviceType = SERVICE_TYPE
+        si.serviceName = this.serviceName
+        si.network = testNetwork1.network
+        si.port = 12345 // Test won't try to connect so port does not matter
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        registerService(registrationRecord, si)
+        val discoveryRecord = NsdDiscoveryRecord()
+        val discoveryRecord2 = NsdDiscoveryRecord()
+        val discoveryRecord3 = NsdDiscoveryRecord()
+
+        tryTest {
+            // Discover service on testNetwork1.
+            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork1.network, Executor { it.run() }, discoveryRecord)
+            // Expect that service is found on testNetwork1
+            val foundInfo = discoveryRecord.waitForServiceDiscovered(
+                serviceName, testNetwork1.network)
+            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))
+
+            // Discover service on testNetwork2.
+            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork2.network, Executor { it.run() }, discoveryRecord2)
+            // Expect that discovery is started then no other callbacks.
+            discoveryRecord2.expectCallback<DiscoveryStarted>()
+            discoveryRecord2.assertNoCallback()
+
+            // Discover service on all networks (no network specified).
+            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
+                null as Network? /* network */, Executor { it.run() }, discoveryRecord3)
+            // Expect that service is found on testNetwork1
+            val foundInfo3 = discoveryRecord3.waitForServiceDiscovered(
+                    serviceName, testNetwork1.network)
+            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo3))
+        } cleanupStep {
+            // Stop every discovery started above so none outlives the test. No callback is
+            // asserted here: unrelated events may still be queued on this record.
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord2)
+            discoveryRecord2.expectCallback<DiscoveryStopped>()
+        } cleanupStep {
+            // As for discoveryRecord: stop only, without asserting on callback order.
+            nsdManager.stopServiceDiscovery(discoveryRecord3)
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord)
+        }
+    }
+
     /**
      * Register a service and return its registration record.
      */