Use recording callbacks in NsdManagerTest

Refactor the test to use recording callbacks based on ArrayTrackRecord,
which allow removing the test's own logic to poll for events.

Bug: 190249673
Test: atest NsdManagerTest --rerun-until-failure 20
Change-Id: Iad0b0d52271b13954c0193b3b9d4307349a39443
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 8daf720..9307c27 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -15,15 +15,29 @@
  */
 package android.net.cts
 
+import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted
+import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped
+import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound
+import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost
+import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StartDiscoveryFailed
+import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StopDiscoveryFailed
+import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.RegistrationFailed
+import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered
+import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
+import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.UnregistrationFailed
+import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveFailed
+import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ServiceResolved
 import android.net.nsd.NsdManager
 import android.net.nsd.NsdManager.DiscoveryListener
+import android.net.nsd.NsdManager.RegistrationListener
 import android.net.nsd.NsdManager.ResolveListener
 import android.net.nsd.NsdServiceInfo
-import android.os.SystemClock
 import android.platform.test.annotations.AppModeFull
 import android.util.Log
 import androidx.test.platform.app.InstrumentationRegistry
 import androidx.test.runner.AndroidJUnit4
+import com.android.net.module.util.ArrayTrackRecord
+import com.android.net.module.util.TrackRecord
 import org.junit.Assert.assertArrayEquals
 import org.junit.Assert.assertTrue
 import org.junit.Test
@@ -35,11 +49,12 @@
 import kotlin.test.assertFailsWith
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
+import kotlin.test.assertTrue
 import kotlin.test.fail
 
 private const val TAG = "NsdManagerTest"
 private const val SERVICE_TYPE = "_nmt._tcp"
-private const val TIMEOUT = 2000
+private const val TIMEOUT_MS = 2000L
 private const val DBG = false
 
 @AppModeFull(reason = "Socket cannot bind in instant app mode")
@@ -49,160 +64,128 @@
     private val nsdManager by lazy { context.getSystemService(NsdManager::class.java) }
     private val serviceName = "NsdTest%04d".format(Random().nextInt(1000))
 
-    private val registrationListener = object : NsdManager.RegistrationListener {
-        override fun onRegistrationFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
-            setEvent("onRegistrationFailed", errorCode)
-        }
+    private interface NsdEvent
+    private open class NsdRecord<T : NsdEvent> private constructor(
+        private val history: ArrayTrackRecord<T>
+    ) : TrackRecord<T> by history {
+        constructor() : this(ArrayTrackRecord())
 
-        override fun onUnregistrationFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
-            setEvent("onUnregistrationFailed", errorCode)
-        }
+        val nextEvents = history.newReadHead()
 
-        override fun onServiceRegistered(serviceInfo: NsdServiceInfo) {
-            setEvent("onServiceRegistered", serviceInfo)
-        }
+        inline fun <reified V : NsdEvent> expectCallbackEventually(
+            crossinline predicate: (V) -> Boolean = { true }
+        ): V = nextEvents.poll(TIMEOUT_MS) { e -> e is V && predicate(e) } as V?
+                ?: fail("Callback for ${V::class.java.simpleName} not seen after $TIMEOUT_MS ms")
 
-        override fun onServiceUnregistered(serviceInfo: NsdServiceInfo) {
-            setEvent("onServiceUnregistered", serviceInfo)
+        inline fun <reified V : NsdEvent> expectCallback(): V {
+            val nextEvent = nextEvents.poll(TIMEOUT_MS)
+            assertNotNull(nextEvent, "No callback received after $TIMEOUT_MS ms")
+            assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
+                    nextEvent.javaClass.simpleName)
+            return nextEvent
         }
     }
 
-    private val discoveryListener = object : DiscoveryListener {
-        override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {
-            setEvent("onStartDiscoveryFailed", errorCode)
+    private class NsdRegistrationRecord : RegistrationListener,
+            NsdRecord<NsdRegistrationRecord.RegistrationEvent>() {
+        sealed class RegistrationEvent : NsdEvent {
+            abstract val serviceInfo: NsdServiceInfo
+
+            data class RegistrationFailed(
+                override val serviceInfo: NsdServiceInfo,
+                val errorCode: Int
+            ) : RegistrationEvent()
+
+            data class UnregistrationFailed(
+                override val serviceInfo: NsdServiceInfo,
+                val errorCode: Int
+            ) : RegistrationEvent()
+
+            data class ServiceRegistered(override val serviceInfo: NsdServiceInfo)
+                : RegistrationEvent()
+            data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo)
+                : RegistrationEvent()
         }
 
-        override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {
-            setEvent("onStopDiscoveryFailed", errorCode)
+        override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
+            add(RegistrationFailed(si, err))
+        }
+
+        override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
+            add(UnregistrationFailed(si, err))
+        }
+
+        override fun onServiceRegistered(si: NsdServiceInfo) {
+            add(ServiceRegistered(si))
+        }
+
+        override fun onServiceUnregistered(si: NsdServiceInfo) {
+            add(ServiceUnregistered(si))
+        }
+    }
+
+    private class NsdDiscoveryRecord : DiscoveryListener,
+            NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>() {
+        sealed class DiscoveryEvent : NsdEvent {
+            data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int)
+                : DiscoveryEvent()
+
+            data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int)
+                : DiscoveryEvent()
+
+            data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
+            data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
+            data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
+            data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
+        }
+
+        override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
+            add(StartDiscoveryFailed(serviceType, err))
+        }
+
+        override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
+            add(StopDiscoveryFailed(serviceType, err))
         }
 
         override fun onDiscoveryStarted(serviceType: String) {
-            val info = NsdServiceInfo()
-            info.serviceType = serviceType
-            setEvent("onDiscoveryStarted", info)
+            add(DiscoveryStarted(serviceType))
         }
 
         override fun onDiscoveryStopped(serviceType: String) {
-            val info = NsdServiceInfo()
-            info.serviceType = serviceType
-            setEvent("onDiscoveryStopped", info)
+            add(DiscoveryStopped(serviceType))
         }
 
-        override fun onServiceFound(serviceInfo: NsdServiceInfo) {
-            setEvent("onServiceFound", serviceInfo)
+        override fun onServiceFound(si: NsdServiceInfo) {
+            add(ServiceFound(si))
         }
 
-        override fun onServiceLost(serviceInfo: NsdServiceInfo) {
-            setEvent("onServiceLost", serviceInfo)
+        override fun onServiceLost(si: NsdServiceInfo) {
+            add(ServiceLost(si))
+        }
+
+        fun waitForServiceDiscovered(serviceName: String): NsdServiceInfo {
+            return expectCallbackEventually<ServiceFound> {
+                it.serviceInfo.serviceName == serviceName
+            }.serviceInfo
         }
     }
 
-    private inner class TestResolveListener : ResolveListener {
-        var resolvedService: NsdServiceInfo? = null
-        override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {
-            setEvent("onResolveFailed", errorCode)
+    private class NsdResolveRecord : ResolveListener,
+            NsdRecord<NsdResolveRecord.ResolveEvent>() {
+        sealed class ResolveEvent : NsdEvent {
+            data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int)
+                : ResolveEvent()
+
+            data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
         }
 
-        override fun onServiceResolved(serviceInfo: NsdServiceInfo) {
-            resolvedService = serviceInfo
-            setEvent("onServiceResolved", serviceInfo)
-        }
-    }
-
-    private class EventData {
-        constructor(callbackName: String, info: NsdServiceInfo?) {
-            this.callbackName = callbackName
-            succeeded = true
-            errorCode = 0
-            this.info = info
+        override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
+            add(ResolveFailed(si, err))
         }
 
-        constructor(callbackName: String, errorCode: Int) {
-            this.callbackName = callbackName
-            succeeded = false
-            this.errorCode = errorCode
-            info = null
+        override fun onServiceResolved(si: NsdServiceInfo) {
+            add(ServiceResolved(si))
         }
-
-        val callbackName: String
-        val succeeded: Boolean
-        private val errorCode: Int
-        val info: NsdServiceInfo?
-    }
-
-    private val eventCache = ArrayList<EventData>()
-    private fun setEvent(callbackName: String, errorCode: Int) {
-        if (DBG) Log.d(TAG, "$callbackName failed with $errorCode")
-        val eventData = EventData(callbackName, errorCode)
-        synchronized(eventCache) {
-            eventCache.add(eventData)
-            eventCache.notify()
-        }
-    }
-
-    private fun setEvent(callbackName: String, info: NsdServiceInfo) {
-        if (DBG) Log.d(TAG, "Received event " + callbackName + " for " + info.serviceName)
-        val eventData = EventData(callbackName, info)
-        synchronized(eventCache) {
-            eventCache.add(eventData)
-            eventCache.notify()
-        }
-    }
-
-    fun clearEventCache() {
-        synchronized(eventCache) { eventCache.clear() }
-    }
-
-    fun eventCacheSize(): Int {
-        synchronized(eventCache) { return eventCache.size }
-    }
-
-    private var waitId = 0
-    private fun waitForCallback(callbackName: String): EventData? {
-        synchronized(eventCache) {
-            waitId++
-            if (DBG) Log.d(TAG, "Waiting for $callbackName, id=$waitId")
-            val startTime = SystemClock.uptimeMillis()
-            var elapsedTime = 0L
-            while (elapsedTime < TIMEOUT) {
-                // first check if we've received that event
-                eventCache.find { it.callbackName == callbackName }?.let {
-                    if (DBG) Log.d(TAG, "exiting wait id=$waitId")
-                    return it
-                }
-
-                // Not yet received, just wait
-                try {
-                    eventCache.wait(TIMEOUT - elapsedTime)
-                } catch (e: InterruptedException) {
-                    return null
-                }
-                elapsedTime = SystemClock.uptimeMillis() - startTime
-            }
-            // we exited the loop because of TIMEOUT; fail the call
-            if (DBG) Log.d(TAG, "timed out waiting id=$waitId")
-            return null
-        }
-    }
-
-    private fun waitForNewEvents(): EventData? {
-        if (DBG) Log.d(TAG, "Waiting for a bit, id=$waitId")
-        val startTime = SystemClock.uptimeMillis()
-        var elapsedTime = 0L
-        synchronized(eventCache) {
-            val index = eventCache.size
-            while (elapsedTime < TIMEOUT) {
-                // first check if we've received that event
-                if (index < eventCache.size) {
-                    return eventCache[index]
-                }
-
-                // Not yet received, just wait
-                eventCache.wait(TIMEOUT - elapsedTime)
-                elapsedTime = SystemClock.uptimeMillis() - startTime
-            }
-        }
-        return null
     }
 
     @Test
@@ -210,38 +193,30 @@
         val si = NsdServiceInfo()
         si.serviceType = SERVICE_TYPE
         si.serviceName = serviceName
+        // Test binary data with various bytes
         val testByteArray = byteArrayOf(-128, 127, 2, 1, 0, 1, 2)
+        // Test string data with 256 characters (25 blocks of 10 characters + 6)
         val string256 = "1_________2_________3_________4_________5_________6_________" +
                 "7_________8_________9_________10________11________12________13________" +
                 "14________15________16________17________18________19________20________" +
                 "21________22________23________24________25________123456"
 
         // Illegal attributes
-        assertFailsWith<IllegalArgumentException>("Could set null key") {
-            si.setAttribute(null, null as String?)
-        }
-        assertFailsWith<IllegalArgumentException>("Could set empty key") {
-            si.setAttribute("", null as String?)
-        }
-        assertFailsWith<IllegalArgumentException>("Could set key with 255 characters") {
-            si.setAttribute(string256, null as String?)
-        }
-        assertFailsWith<IllegalArgumentException>(
-                "Could set key+value combination with more than 255 characters") {
-            si.setAttribute("key", string256.substring(3))
-        }
-        assertFailsWith<IllegalArgumentException>(
-                "Could set key+value combination with 255 characters") {
-            si.setAttribute("key", string256.substring(4))
-        }
-        assertFailsWith<IllegalArgumentException>("Could set key with invalid character") {
-            si.setAttribute("\u0019", null as String?)
-        }
-        assertFailsWith<IllegalArgumentException>("Could set key with invalid character") {
-            si.setAttribute("=", null as String?)
-        }
-        assertFailsWith<IllegalArgumentException>("Could set key with invalid character") {
-            si.setAttribute("\u007f", null as String?)
+        listOf(
+                Triple(null, null, "null key"),
+                Triple("", null, "empty key"),
+                Triple(string256, null, "key with 256 characters"),
+                Triple("key", string256.substring(3),
+                        "key+value combination with more than 255 characters"),
+                Triple("key", string256.substring(4), "key+value combination with 255 characters"),
+                Triple("\u0019", null, "key with invalid character"),
+                Triple("=", null, "key with invalid character"),
+                Triple("\u007f", null, "key with invalid character")
+        ).forEach {
+            assertFailsWith<IllegalArgumentException>(
+                    "Setting invalid ${it.third} unexpectedly succeeded") {
+                si.setAttribute(it.first, it.second)
+            }
         }
 
         // Allowed attributes
@@ -257,28 +232,18 @@
         val localPort = socket.localPort
         si.port = localPort
         if (DBG) Log.d(TAG, "Port = $localPort")
-        clearEventCache()
 
-        val registeredName = registerService(si)
+        val registrationRecord = NsdRegistrationRecord()
+        val registeredInfo = registerService(registrationRecord, si)
 
-        assertEquals(1, eventCacheSize())
-        clearEventCache()
-        nsdManager.discoverServices(SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, discoveryListener)
+        val discoveryRecord = NsdDiscoveryRecord()
+        nsdManager.discoverServices(SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
 
         // Expect discovery started
-        var lastEvent = waitForCallback("onDiscoveryStarted")
-        assertNotNull(lastEvent)
-        assertTrue(lastEvent.succeeded)
-
-        // Remove this event, so accounting becomes easier later
-        synchronized(eventCache) { eventCache.remove(lastEvent) }
+        discoveryRecord.expectCallback<DiscoveryStarted>()
 
         // Expect a service record to be discovered
-        val foundInfo = waitForServiceDiscovered(registeredName)
-
-        // We've removed all serviceFound events, and we've removed the discoveryStarted
-        // event as well, so now the event cache should be empty!
-        assertEquals(0, eventCacheSize())
+        val foundInfo = discoveryRecord.waitForServiceDiscovered(registeredInfo.serviceName)
 
         val resolvedService = resolveService(foundInfo)
 
@@ -296,119 +261,61 @@
         assertNull(resolvedService.attributes["nullBinaryDataAttr"])
         assertTrue(resolvedService.attributes.containsKey("emptyBinaryDataAttr"))
         assertNull(resolvedService.attributes["emptyBinaryDataAttr"])
-        if (DBG) Log.d(TAG, "id = $waitId: Port = ${lastEvent.info?.port}")
         assertEquals(localPort, resolvedService.port)
-        assertEquals(1, eventCacheSize())
-        checkForAdditionalEvents()
-        clearEventCache()
 
         // Unregister the service
-        nsdManager.unregisterService(registrationListener)
-        lastEvent = waitForCallback("onServiceUnregistered")
-        assertNotNull(lastEvent)
-        assertTrue(lastEvent.succeeded)
+        nsdManager.unregisterService(registrationRecord)
+        registrationRecord.expectCallback<ServiceUnregistered>()
 
         // Expect a callback for service lost
-        lastEvent = waitForCallback("onServiceLost")
-        assertNotNull(lastEvent)
-        assertEquals(registeredName, lastEvent.info?.serviceName)
+        discoveryRecord.expectCallbackEventually<ServiceLost> {
+            it.serviceInfo.serviceName == serviceName
+        }
 
-        // Register service again to see if we discover it
-        checkForAdditionalEvents()
-        clearEventCache()
+        // Register service again to see if NsdManager can discover it
         val si2 = NsdServiceInfo()
         si2.serviceType = SERVICE_TYPE
         si2.serviceName = serviceName
         si2.port = localPort
-        val registeredName2 = registerService(si2)
+        val registrationRecord2 = NsdRegistrationRecord()
+        val registeredInfo2 = registerService(registrationRecord2, si2)
 
-        // Expect a record to be discovered
         // Expect a service record to be discovered (and filter the ones
         // that are unrelated to this test)
-        val foundInfo2 = waitForServiceDiscovered(registeredName2)
+        val foundInfo2 = discoveryRecord.waitForServiceDiscovered(registeredInfo2.serviceName)
 
         // Resolve the service
-        clearEventCache()
         val resolvedService2 = resolveService(foundInfo2)
 
-        // Check that we don't have any TXT records
+        // Check that the resolved service doesn't have any TXT records
         assertEquals(0, resolvedService2.attributes.size)
-        checkForAdditionalEvents()
-        clearEventCache()
-        nsdManager.stopServiceDiscovery(discoveryListener)
-        lastEvent = waitForCallback("onDiscoveryStopped")
-        assertNotNull(lastEvent)
-        assertTrue(lastEvent.succeeded)
-        checkCacheSize(1)
-        checkForAdditionalEvents()
-        clearEventCache()
-        nsdManager.unregisterService(registrationListener)
-        lastEvent = waitForCallback("onServiceUnregistered")
-        assertNotNull(lastEvent)
-        assertTrue(lastEvent.succeeded)
-        checkCacheSize(1)
+
+        nsdManager.stopServiceDiscovery(discoveryRecord)
+
+        discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
+
+        nsdManager.unregisterService(registrationRecord2)
+        registrationRecord2.expectCallback<ServiceUnregistered>()
     }
 
     /**
-     * Register a service and return its registered name.
+     * Register a service and return its registration record.
      */
-    private fun registerService(si: NsdServiceInfo): String {
-        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, registrationListener)
-
+    private fun registerService(record: NsdRegistrationRecord, si: NsdServiceInfo): NsdServiceInfo {
+        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, record)
         // We may not always get the name that we tried to register;
         // This events tells us the name that was registered.
-        val cb = waitForCallback("onServiceRegistered")
-        assertNotNull(cb)
-        assertTrue(cb.succeeded)
-        return cb.info?.serviceName ?: fail("Missing event info")
-    }
-
-    private fun waitForServiceDiscovered(serviceName: String): NsdServiceInfo {
-        var foundInfo: NsdServiceInfo? = null
-        repeat(32) {
-            val event = waitForCallback("onServiceFound") ?: return@repeat
-            assertTrue(event.succeeded)
-            if (DBG) Log.d(TAG, "id = $waitId: ServiceName = ${event.info?.serviceName}")
-            if (event.info?.serviceName == serviceName) {
-                // Save it, as it will get overwritten with new serviceFound events
-                foundInfo = event.info
-            }
-
-            // Remove this event from the event cache, so it won't be found by subsequent
-            // calls to waitForCallback
-            synchronized(eventCache) { eventCache.remove(event) }
-        }
-        return foundInfo ?: fail("Service not discovered")
+        val cb = record.expectCallback<ServiceRegistered>()
+        return cb.serviceInfo
     }
 
     private fun resolveService(discoveredInfo: NsdServiceInfo): NsdServiceInfo {
-        val resolveListener = TestResolveListener()
-        nsdManager.resolveService(discoveredInfo, resolveListener)
-        val resolvedCb = waitForCallback("onServiceResolved")
-        assertNotNull(resolvedCb)
-        assertTrue(resolvedCb.succeeded)
-        if (DBG) Log.d(TAG, "id = $waitId: ServiceName = ${resolvedCb.info?.serviceName}")
-        assertEquals(discoveredInfo.serviceName, resolvedCb.info?.serviceName)
+        val record = NsdResolveRecord()
+        nsdManager.resolveService(discoveredInfo, record)
+        val resolvedCb = record.expectCallback<ServiceResolved>()
+        assertEquals(discoveredInfo.serviceName, resolvedCb.serviceInfo.serviceName)
 
-        return resolveListener.resolvedService ?: fail("Missing resolved service")
-    }
-
-    fun checkCacheSize(size: Int) {
-        synchronized(eventCache) {
-            if (size != eventCache.size) {
-                fail("Expected size $size, found event list [${
-                    eventCache.joinToString(", ") {
-                        "eventName: ${it.callbackName}, serviceName ${it.info?.serviceName}"
-                    }
-                }]")
-            }
-        }
-    }
-
-    fun checkForAdditionalEvents(): Boolean {
-        val e = waitForNewEvents() ?: return true
-        Log.d(TAG, "ignoring unexpected event ${e.callbackName} (${e.info?.serviceName})")
-        return false
+        return resolvedCb.serviceInfo
     }
 }
 
@@ -416,7 +323,3 @@
     if (this == null) return ""
     return String(this, StandardCharsets.UTF_8)
 }
-
-// TODO: migrate legacy java-style implementation to newer utils like RecorderCallback
-private fun Any.wait(timeout: Long) = (this as Object).wait(timeout)
-private fun Any.notify() = (this as Object).notify()
\ No newline at end of file