Merge changes from topic "cherrypicker-L34800000954905788:N57900001270872696" into tm-dev

* changes:
  Ignore new test until prebuilts are updated.
  Add deny firewall chain for OEM
diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java
index 33b44c8..f19bf4a 100644
--- a/framework-t/src/android/net/nsd/NsdManager.java
+++ b/framework-t/src/android/net/nsd/NsdManager.java
@@ -312,9 +312,12 @@
             @Override
             public void onAvailable(@NonNull Network network) {
                 final DelegatingDiscoveryListener wrappedListener = new DelegatingDiscoveryListener(
-                        network, mBaseListener);
+                        network, mBaseListener, mBaseExecutor);
                 mPerNetworkListeners.put(network, wrappedListener);
-                discoverServices(mServiceType, mProtocolType, network, mBaseExecutor,
+                // Run discovery callbacks inline on the service handler thread, which is the
+                // same thread used by this NetworkCallback, but DelegatingDiscoveryListener will
+                // use the base executor to run the wrapped callbacks.
+                discoverServices(mServiceType, mProtocolType, network, Runnable::run,
                         wrappedListener);
             }
 
@@ -334,7 +337,8 @@
         public void start(@NonNull NetworkRequest request) {
             final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
             cm.registerNetworkCallback(request, mNetworkCb, mHandler);
-            mHandler.post(() -> mBaseListener.onDiscoveryStarted(mServiceType));
+            mHandler.post(() -> mBaseExecutor.execute(() ->
+                    mBaseListener.onDiscoveryStarted(mServiceType)));
         }
 
         /**
@@ -351,7 +355,7 @@
                 final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
                 cm.unregisterNetworkCallback(mNetworkCb);
                 if (mPerNetworkListeners.size() == 0) {
-                    mBaseListener.onDiscoveryStopped(mServiceType);
+                    mBaseExecutor.execute(() -> mBaseListener.onDiscoveryStopped(mServiceType));
                     return;
                 }
                 for (int i = 0; i < mPerNetworkListeners.size(); i++) {
@@ -399,14 +403,23 @@
             }
         }
 
+        /**
+         * A listener wrapping calls to an app-provided listener, while keeping track of found
+         * services, so they can all be reported lost when the underlying network is lost.
+         *
+         * This should be registered to run on the service handler.
+         */
         private class DelegatingDiscoveryListener implements DiscoveryListener {
             private final Network mNetwork;
             private final DiscoveryListener mWrapped;
+            private final Executor mWrappedExecutor;
             private final ArraySet<TrackedNsdInfo> mFoundInfo = new ArraySet<>();
 
-            private DelegatingDiscoveryListener(Network network, DiscoveryListener listener) {
+            private DelegatingDiscoveryListener(Network network, DiscoveryListener listener,
+                    Executor executor) {
                 mNetwork = network;
                 mWrapped = listener;
+                mWrappedExecutor = executor;
             }
 
             void notifyAllServicesLost() {
@@ -415,7 +428,7 @@
                     final NsdServiceInfo serviceInfo = new NsdServiceInfo(
                             trackedInfo.mServiceName, trackedInfo.mServiceType);
                     serviceInfo.setNetwork(mNetwork);
-                    mWrapped.onServiceLost(serviceInfo);
+                    mWrappedExecutor.execute(() -> mWrapped.onServiceLost(serviceInfo));
                 }
             }
 
@@ -444,7 +457,7 @@
                     // Do not report onStopDiscoveryFailed when some underlying listeners failed:
                     // this does not mean that all listeners did, and onStopDiscoveryFailed is not
                     // actionable anyway. Just report that discovery stopped.
-                    mWrapped.onDiscoveryStopped(serviceType);
+                    mWrappedExecutor.execute(() -> mWrapped.onDiscoveryStopped(serviceType));
                 }
             }
 
@@ -452,20 +465,20 @@
             public void onDiscoveryStopped(String serviceType) {
                 mPerNetworkListeners.remove(mNetwork);
                 if (mStopRequested && mPerNetworkListeners.size() == 0) {
-                    mWrapped.onDiscoveryStopped(serviceType);
+                    mWrappedExecutor.execute(() -> mWrapped.onDiscoveryStopped(serviceType));
                 }
             }
 
             @Override
             public void onServiceFound(NsdServiceInfo serviceInfo) {
                 mFoundInfo.add(new TrackedNsdInfo(serviceInfo));
-                mWrapped.onServiceFound(serviceInfo);
+                mWrappedExecutor.execute(() -> mWrapped.onServiceFound(serviceInfo));
             }
 
             @Override
             public void onServiceLost(NsdServiceInfo serviceInfo) {
                 mFoundInfo.remove(new TrackedNsdInfo(serviceInfo));
-                mWrapped.onServiceLost(serviceInfo);
+                mWrappedExecutor.execute(() -> mWrapped.onServiceLost(serviceInfo));
             }
         }
     }
@@ -648,8 +661,12 @@
 
         @Override
         public void handleMessage(Message message) {
+            // Do not use message in the executor lambdas, as it will be recycled once this method
+            // returns. Keep references to its content instead.
             final int what = message.what;
+            final int errorCode = message.arg1;
             final int key = message.arg2;
+            final Object obj = message.obj;
             final Object listener;
             final NsdServiceInfo ns;
             final Executor executor;
@@ -659,7 +676,7 @@
                 executor = mExecutorMap.get(key);
             }
             if (listener == null) {
-                Log.d(TAG, "Stale key " + message.arg2);
+                Log.d(TAG, "Stale key " + key);
                 return;
             }
             if (DBG) {
@@ -667,28 +684,28 @@
             }
             switch (what) {
                 case DISCOVER_SERVICES_STARTED:
-                    final String s = getNsdServiceInfoType((NsdServiceInfo) message.obj);
+                    final String s = getNsdServiceInfoType((NsdServiceInfo) obj);
                     executor.execute(() -> ((DiscoveryListener) listener).onDiscoveryStarted(s));
                     break;
                 case DISCOVER_SERVICES_FAILED:
                     removeListener(key);
                     executor.execute(() -> ((DiscoveryListener) listener).onStartDiscoveryFailed(
-                            getNsdServiceInfoType(ns), message.arg1));
+                            getNsdServiceInfoType(ns), errorCode));
                     break;
                 case SERVICE_FOUND:
                     executor.execute(() -> ((DiscoveryListener) listener).onServiceFound(
-                            (NsdServiceInfo) message.obj));
+                            (NsdServiceInfo) obj));
                     break;
                 case SERVICE_LOST:
                     executor.execute(() -> ((DiscoveryListener) listener).onServiceLost(
-                            (NsdServiceInfo) message.obj));
+                            (NsdServiceInfo) obj));
                     break;
                 case STOP_DISCOVERY_FAILED:
                     // TODO: failure to stop discovery should be internal and retried internally, as
                     // the effect for the client is indistinguishable from STOP_DISCOVERY_SUCCEEDED
                     removeListener(key);
                     executor.execute(() -> ((DiscoveryListener) listener).onStopDiscoveryFailed(
-                            getNsdServiceInfoType(ns), message.arg1));
+                            getNsdServiceInfoType(ns), errorCode));
                     break;
                 case STOP_DISCOVERY_SUCCEEDED:
                     removeListener(key);
@@ -698,33 +715,33 @@
                 case REGISTER_SERVICE_FAILED:
                     removeListener(key);
                     executor.execute(() -> ((RegistrationListener) listener).onRegistrationFailed(
-                            ns, message.arg1));
+                            ns, errorCode));
                     break;
                 case REGISTER_SERVICE_SUCCEEDED:
                     executor.execute(() -> ((RegistrationListener) listener).onServiceRegistered(
-                            (NsdServiceInfo) message.obj));
+                            (NsdServiceInfo) obj));
                     break;
                 case UNREGISTER_SERVICE_FAILED:
                     removeListener(key);
                     executor.execute(() -> ((RegistrationListener) listener).onUnregistrationFailed(
-                            ns, message.arg1));
+                            ns, errorCode));
                     break;
                 case UNREGISTER_SERVICE_SUCCEEDED:
                     // TODO: do not unregister listener until service is unregistered, or provide
                     // alternative way for unregistering ?
-                    removeListener(message.arg2);
+                    removeListener(key);
                     executor.execute(() -> ((RegistrationListener) listener).onServiceUnregistered(
                             ns));
                     break;
                 case RESOLVE_SERVICE_FAILED:
                     removeListener(key);
                     executor.execute(() -> ((ResolveListener) listener).onResolveFailed(
-                            ns, message.arg1));
+                            ns, errorCode));
                     break;
                 case RESOLVE_SERVICE_SUCCEEDED:
                     removeListener(key);
                     executor.execute(() -> ((ResolveListener) listener).onServiceResolved(
-                            (NsdServiceInfo) message.obj));
+                            (NsdServiceInfo) obj));
                     break;
                 default:
                     Log.d(TAG, "Ignored " + message);
diff --git a/service-t/src/com/android/server/ethernet/EthernetTracker.java b/service-t/src/com/android/server/ethernet/EthernetTracker.java
index 709b774..1ab7515 100644
--- a/service-t/src/com/android/server/ethernet/EthernetTracker.java
+++ b/service-t/src/com/android/server/ethernet/EthernetTracker.java
@@ -229,7 +229,7 @@
      */
     protected void broadcastInterfaceStateChange(@NonNull String iface) {
         ensureRunningOnEthernetServiceThread();
-        final int state = mFactory.getInterfaceState(iface);
+        final int state = getInterfaceState(iface);
         final int role = getInterfaceRole(iface);
         final IpConfiguration config = getIpConfigurationForCallback(iface, state);
         final int n = mListeners.beginBroadcast();
@@ -436,15 +436,34 @@
         if (mDefaultInterface != null) {
             removeInterface(mDefaultInterface);
             addInterface(mDefaultInterface);
+            // when this broadcast is sent, any calls to notifyTetheredInterfaceAvailable or
+            // notifyTetheredInterfaceUnavailable have already happened
+            broadcastInterfaceStateChange(mDefaultInterface);
         }
     }
 
+    private int getInterfaceState(final String iface) {
+        if (mFactory.hasInterface(iface)) {
+            return mFactory.getInterfaceState(iface);
+        }
+        if (getInterfaceMode(iface) == INTERFACE_MODE_SERVER) {
+            // server mode interfaces are not tracked by the factory.
+            // TODO(b/234743836): interface state for server mode interfaces is not tracked
+            // properly; just return link up.
+            return EthernetManager.STATE_LINK_UP;
+        }
+        return EthernetManager.STATE_ABSENT;
+    }
+
     private int getInterfaceRole(final String iface) {
-        if (!mFactory.hasInterface(iface)) return EthernetManager.ROLE_NONE;
-        final int mode = getInterfaceMode(iface);
-        return (mode == INTERFACE_MODE_CLIENT)
-                ? EthernetManager.ROLE_CLIENT
-                : EthernetManager.ROLE_SERVER;
+        if (mFactory.hasInterface(iface)) {
+            // only client mode interfaces are tracked by the factory.
+            return EthernetManager.ROLE_CLIENT;
+        }
+        if (getInterfaceMode(iface) == INTERFACE_MODE_SERVER) {
+            return EthernetManager.ROLE_SERVER;
+        }
+        return EthernetManager.ROLE_NONE;
     }
 
     private int getInterfaceMode(final String iface) {
diff --git a/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt b/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt
index bfc9b29..293da67 100644
--- a/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt
@@ -20,6 +20,16 @@
 import android.Manifest.permission.NETWORK_SETTINGS
 import android.content.Context
 import android.net.ConnectivityManager
+import android.net.EthernetManager
+import android.net.EthernetManager.InterfaceStateListener
+import android.net.EthernetManager.ROLE_CLIENT
+import android.net.EthernetManager.ROLE_NONE
+import android.net.EthernetManager.ROLE_SERVER
+import android.net.EthernetManager.STATE_ABSENT
+import android.net.EthernetManager.STATE_LINK_DOWN
+import android.net.EthernetManager.STATE_LINK_UP
+import android.net.EthernetManager.TetheredInterfaceCallback
+import android.net.EthernetManager.TetheredInterfaceRequest
 import android.net.EthernetNetworkSpecifier
 import android.net.InetAddresses
 import android.net.IpConfiguration
@@ -32,47 +42,47 @@
 import android.net.TestNetworkInterface
 import android.net.TestNetworkManager
 import android.net.cts.EthernetManagerTest.EthernetStateListener.CallbackEntry.InterfaceStateChanged
+import android.os.Build
 import android.os.Handler
 import android.os.HandlerExecutor
 import android.os.Looper
+import android.os.SystemProperties
 import android.platform.test.annotations.AppModeFull
 import android.util.ArraySet
 import androidx.test.platform.app.InstrumentationRegistry
-import androidx.test.runner.AndroidJUnit4
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.net.module.util.TrackRecord
-import com.android.networkstack.apishim.EthernetManagerShimImpl
-import com.android.networkstack.apishim.common.EthernetManagerShim.InterfaceStateListener
-import com.android.networkstack.apishim.common.EthernetManagerShim.ROLE_CLIENT
-import com.android.networkstack.apishim.common.EthernetManagerShim.ROLE_NONE
-import com.android.networkstack.apishim.common.EthernetManagerShim.STATE_ABSENT
-import com.android.networkstack.apishim.common.EthernetManagerShim.STATE_LINK_DOWN
-import com.android.networkstack.apishim.common.EthernetManagerShim.STATE_LINK_UP
 import com.android.testutils.anyNetwork
 import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.RecorderCallback.CallbackEntry.Available
 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
 import com.android.testutils.RouterAdvertisementResponder
-import com.android.testutils.SC_V2
 import com.android.testutils.TapPacketReader
 import com.android.testutils.TestableNetworkCallback
 import com.android.testutils.runAsShell
 import com.android.testutils.waitForIdle
 import org.junit.After
+import org.junit.Assume.assumeFalse
 import org.junit.Before
-import org.junit.Rule
 import org.junit.Ignore
 import org.junit.Test
 import org.junit.runner.RunWith
 import java.net.Inet6Address
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.ExecutionException
+import java.util.concurrent.TimeUnit
 import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
 import kotlin.test.assertTrue
 import kotlin.test.fail
 
-private const val TIMEOUT_MS = 1000L
+// TODO: try to lower this timeout in the future. Currently, ethernet tests are still flaky because
+// the interface is not ready fast enough (mostly due to the up / up / down / up issue).
+private const val TIMEOUT_MS = 2000L
 private const val NO_CALLBACK_TIMEOUT_MS = 200L
 private val DEFAULT_IP_CONFIGURATION = IpConfiguration(IpConfiguration.IpAssignment.DHCP,
     IpConfiguration.ProxySettings.NONE, null, null)
@@ -83,14 +93,13 @@
     .build()
 
 @AppModeFull(reason = "Instant apps can't access EthernetManager")
-@RunWith(AndroidJUnit4::class)
+// EthernetManager is not updatable before T, so tests do not need to be backwards compatible.
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class EthernetManagerTest {
-    // EthernetManager is not updatable before T, so tests do not need to be backwards compatible
-    @get:Rule
-    val ignoreRule = DevSdkIgnoreRule(ignoreClassUpTo = SC_V2)
 
     private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
-    private val em by lazy { EthernetManagerShimImpl.newInstance(context) }
+    private val em by lazy { context.getSystemService(EthernetManager::class.java) }
     private val cm by lazy { context.getSystemService(ConnectivityManager::class.java) }
 
     private val ifaceListener = EthernetStateListener()
@@ -98,6 +107,8 @@
     private val addedListeners = ArrayList<EthernetStateListener>()
     private val networkRequests = ArrayList<TestableNetworkCallback>()
 
+    private var tetheredInterfaceRequest: TetheredInterfaceRequest? = null
+
     private class EthernetTestInterface(
         context: Context,
         private val handler: Handler
@@ -162,11 +173,11 @@
         }
 
         fun expectCallback(iface: EthernetTestInterface, state: Int, role: Int) {
-            expectCallback(createChangeEvent(iface, state, role))
+            expectCallback(createChangeEvent(iface.interfaceName, state, role))
         }
 
-        fun createChangeEvent(iface: EthernetTestInterface, state: Int, role: Int) =
-                InterfaceStateChanged(iface.interfaceName, state, role,
+        fun createChangeEvent(iface: String, state: Int, role: Int) =
+                InterfaceStateChanged(iface, state, role,
                         if (state != STATE_ABSENT) DEFAULT_IP_CONFIGURATION else null)
 
         fun pollForNextCallback(): CallbackEntry {
@@ -175,8 +186,12 @@
 
         fun eventuallyExpect(expected: CallbackEntry) = events.poll(TIMEOUT_MS) { it == expected }
 
+        fun eventuallyExpect(interfaceName: String, state: Int, role: Int) {
+            assertNotNull(eventuallyExpect(createChangeEvent(interfaceName, state, role)))
+        }
+
         fun eventuallyExpect(iface: EthernetTestInterface, state: Int, role: Int) {
-            assertNotNull(eventuallyExpect(createChangeEvent(iface, state, role)))
+            eventuallyExpect(iface.interfaceName, state, role)
         }
 
         fun assertNoCallback() {
@@ -185,6 +200,34 @@
         }
     }
 
+    private class TetheredInterfaceListener : TetheredInterfaceCallback {
+        private val available = CompletableFuture<String>()
+
+        override fun onAvailable(iface: String) {
+            available.complete(iface)
+        }
+
+        override fun onUnavailable() {
+            available.completeExceptionally(IllegalStateException("onUnavailable was called"))
+        }
+
+        fun expectOnAvailable(): String {
+            return available.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+        }
+
+        fun expectOnUnavailable() {
+            // Assert that the future fails with the IllegalStateException from the
+            // completeExceptionally() call inside onUnavailable.
+            assertFailsWith(IllegalStateException::class) {
+                try {
+                    available.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+                } catch (e: ExecutionException) {
+                    throw e.cause!!
+                }
+            }
+        }
+    }
+
     @Before
     fun setUp() {
         setIncludeTestInterfaces(true)
@@ -202,6 +245,7 @@
             em.removeInterfaceStateListener(listener)
         }
         networkRequests.forEach { cm.unregisterNetworkCallback(it) }
+        releaseTetheredInterface()
     }
 
     private fun addInterfaceStateListener(listener: EthernetStateListener) {
@@ -248,6 +292,19 @@
         networkRequests.remove(cb)
     }
 
+    private fun requestTetheredInterface() = TetheredInterfaceListener().also {
+        tetheredInterfaceRequest = runAsShell(NETWORK_SETTINGS) {
+            em.requestTetheredInterface(HandlerExecutor(Handler(Looper.getMainLooper())), it)
+        }
+    }
+
+    private fun releaseTetheredInterface() {
+        runAsShell(NETWORK_SETTINGS) {
+            tetheredInterfaceRequest?.release()
+            tetheredInterfaceRequest = null
+        }
+    }
+
     private fun NetworkRequest.createCopyWithEthernetSpecifier(ifaceName: String) =
         NetworkRequest.Builder(NetworkRequest(ETH_REQUEST))
             .setNetworkSpecifier(EthernetNetworkSpecifier(ifaceName)).build()
@@ -301,6 +358,35 @@
         }
     }
 
+    // TODO: this function is now used in two places (EthernetManagerTest and
+    // EthernetTetheringTest), so it should be moved to testutils.
+    private fun isAdbOverNetwork(): Boolean {
+        // If the adb TCP port is open, this test may be running via adb over the network.
+        return (SystemProperties.getInt("persist.adb.tcp.port", -1) > -1 ||
+                SystemProperties.getInt("service.adb.tcp.port", -1) > -1)
+    }
+
+    @Ignore("TODO: temporarily ignore tests until prebuilts are updated")
+    @Test
+    fun testCallbacks_forServerModeInterfaces() {
+        // do not run this test when adb might be connected over ethernet.
+        assumeFalse(isAdbOverNetwork())
+
+        val listener = EthernetStateListener()
+        addInterfaceStateListener(listener)
+
+        // it is possible that a physical interface is present, so it is not guaranteed that iface
+        // will be put into server mode. This should not matter for the test though. Calling
+        // createInterface() makes sure we have at least one interface available.
+        val iface = createInterface()
+        val cb = requestTetheredInterface()
+        val ifaceName = cb.expectOnAvailable()
+        listener.eventuallyExpect(ifaceName, STATE_LINK_UP, ROLE_SERVER)
+
+        releaseTetheredInterface()
+        listener.eventuallyExpect(ifaceName, STATE_LINK_UP, ROLE_CLIENT)
+    }
+
     /**
      * Validate all interfaces are returned for an EthernetStateListener upon registration.
      */
@@ -316,7 +402,10 @@
             assertTrue(ifaces.contains(iface), "Untracked interface $iface returned")
             // If the event's iface was created in the test, additional criteria can be validated.
             createdIfaces.find { it.interfaceName.equals(iface) }?.let {
-                assertEquals(event, listener.createChangeEvent(it, STATE_LINK_UP, ROLE_CLIENT))
+                assertEquals(event,
+                    listener.createChangeEvent(it.interfaceName,
+                                                        STATE_LINK_UP,
+                                                        ROLE_CLIENT))
             }
         }
         // Assert all callbacks are accounted for.
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 33a0a83..69ec189 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -22,6 +22,7 @@
 import android.net.Network
 import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
 import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
 import android.net.NetworkCapabilities.TRANSPORT_TEST
 import android.net.NetworkRequest
@@ -45,7 +46,9 @@
 import android.net.nsd.NsdManager.RegistrationListener
 import android.net.nsd.NsdManager.ResolveListener
 import android.net.nsd.NsdServiceInfo
+import android.os.Handler
 import android.os.HandlerThread
+import android.os.Process.myTid
 import android.platform.test.annotations.AppModeFull
 import android.util.Log
 import androidx.test.platform.app.InstrumentationRegistry
@@ -62,6 +65,7 @@
 import org.junit.Assert.assertTrue
 import org.junit.Assume.assumeTrue
 import org.junit.Before
+import org.junit.Ignore
 import org.junit.Test
 import org.junit.runner.RunWith
 import java.net.ServerSocket
@@ -111,12 +115,20 @@
 
     private interface NsdEvent
     private open class NsdRecord<T : NsdEvent> private constructor(
-        private val history: ArrayTrackRecord<T>
+        private val history: ArrayTrackRecord<T>,
+        private val expectedThreadId: Int? = null
     ) : TrackRecord<T> by history {
-        constructor() : this(ArrayTrackRecord())
+        constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)
 
         val nextEvents = history.newReadHead()
 
+        override fun add(e: T): Boolean {
+            if (expectedThreadId != null) {
+                assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread")
+            }
+            return history.add(e)
+        }
+
         inline fun <reified V : NsdEvent> expectCallbackEventually(
             crossinline predicate: (V) -> Boolean = { true }
         ): V = nextEvents.poll(TIMEOUT_MS) { e -> e is V && predicate(e) } as V?
@@ -136,8 +148,8 @@
         }
     }
 
-    private class NsdRegistrationRecord : RegistrationListener,
-            NsdRecord<NsdRegistrationRecord.RegistrationEvent>() {
+    private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener,
+            NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
         sealed class RegistrationEvent : NsdEvent {
             abstract val serviceInfo: NsdServiceInfo
 
@@ -174,8 +186,8 @@
         }
     }
 
-    private class NsdDiscoveryRecord : DiscoveryListener,
-            NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>() {
+    private class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
+            DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
         sealed class DiscoveryEvent : NsdEvent {
             data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int)
                 : DiscoveryEvent()
@@ -452,7 +464,7 @@
         }
     }
 
-    @Test
+    @Test @Ignore // TODO(b/234099453): re-enable when the prebuilt module is updated
     fun testNsdManager_DiscoverWithNetworkRequest() {
         // This test requires shims supporting T+ APIs (discovering on network request)
         assumeTrue(TestUtils.shouldTestTApis())
@@ -462,9 +474,12 @@
         si.serviceName = this.serviceName
         si.port = 12345 // Test won't try to connect so port does not matter
 
-        val registrationRecord = NsdRegistrationRecord()
-        val registeredInfo1 = registerService(registrationRecord, si)
-        val discoveryRecord = NsdDiscoveryRecord()
+        val handler = Handler(handlerThread.looper)
+        val executor = Executor { handler.post(it) }
+
+        val registrationRecord = NsdRegistrationRecord(expectedThreadId = handlerThread.threadId)
+        val registeredInfo1 = registerService(registrationRecord, si, executor)
+        val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId)
 
         tryTest {
             val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
@@ -474,7 +489,7 @@
                             .addTransportType(TRANSPORT_TEST)
                             .setNetworkSpecifier(specifier)
                             .build(),
-                    Executor { it.run() }, discoveryRecord)
+                    executor, discoveryRecord)
 
             val discoveryStarted = discoveryRecord.expectCallback<DiscoveryStarted>()
             assertEquals(SERVICE_TYPE, discoveryStarted.serviceType)
@@ -490,7 +505,7 @@
             assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost1.serviceInfo))
 
             registrationRecord.expectCallback<ServiceUnregistered>()
-            val registeredInfo2 = registerService(registrationRecord, si)
+            val registeredInfo2 = registerService(registrationRecord, si, executor)
             val serviceDiscovered2 = discoveryRecord.expectCallback<ServiceFound>()
             assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo))
@@ -517,6 +532,39 @@
         }
     }
 
+    @Test @Ignore // TODO(b/234099453): re-enable when the prebuilt module is updated
+    fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() {
+        // This test requires shims supporting T+ APIs (discovering on network request)
+        assumeTrue(TestUtils.shouldTestTApis())
+
+        val si = NsdServiceInfo()
+        si.serviceType = SERVICE_TYPE
+        si.serviceName = this.serviceName
+        si.port = 12345 // Test won't try to connect so port does not matter
+
+        val handler = Handler(handlerThread.looper)
+        val executor = Executor { handler.post(it) }
+
+        val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId)
+        val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
+
+        tryTest {
+            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
+                    NetworkRequest.Builder()
+                            .removeCapability(NET_CAPABILITY_TRUSTED)
+                            .addTransportType(TRANSPORT_TEST)
+                            // Specified network does not have this capability
+                            .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                            .setNetworkSpecifier(specifier)
+                            .build(),
+                    executor, discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStarted>()
+        } cleanup {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        }
+    }
+
     @Test
     fun testNsdManager_ResolveOnNetwork() {
         // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
@@ -648,9 +696,12 @@
     /**
      * Register a service and return its registration record.
      */
-    private fun registerService(record: NsdRegistrationRecord, si: NsdServiceInfo): NsdServiceInfo {
-        nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, Executor { it.run() },
-                record)
+    private fun registerService(
+        record: NsdRegistrationRecord,
+        si: NsdServiceInfo,
+        executor: Executor = Executor { it.run() }
+    ): NsdServiceInfo {
+        nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, executor, record)
         // We may not always get the name that we tried to register;
         // This events tells us the name that was registered.
         val cb = record.expectCallback<ServiceRegistered>()
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index 11fbcb9..eb35469 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -45,6 +45,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
@@ -111,6 +112,7 @@
 import android.os.INetworkManagementService;
 import android.os.Looper;
 import android.os.ParcelFileDescriptor;
+import android.os.PowerWhitelistManager;
 import android.os.Process;
 import android.os.UserHandle;
 import android.os.UserManager;
@@ -129,6 +131,7 @@
 import com.android.internal.net.VpnProfile;
 import com.android.internal.util.HexDump;
 import com.android.modules.utils.build.SdkLevel;
+import com.android.server.DeviceIdleInternal;
 import com.android.server.IpSecService;
 import com.android.server.vcn.util.PersistableBundleUtils;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -235,6 +238,7 @@
     @Mock private ConnectivityManager mConnectivityManager;
     @Mock private IpSecService mIpSecService;
     @Mock private VpnProfileStore mVpnProfileStore;
+    @Mock private DeviceIdleInternal mDeviceIdleInternal;
     private final VpnProfile mVpnProfile;
 
     private IpSecManager mIpSecManager;
@@ -408,6 +412,12 @@
                 disallow);
     }
 
+    private void verifyPowerSaveTempWhitelistApp(String packageName) {
+        final DeviceIdleInternal idleVerifier = verify(mDeviceIdleInternal); // verification mode
+        idleVerifier.addPowerSaveTempWhitelistApp(anyInt(), eq(packageName), anyLong(), anyInt(),
+                eq(false), eq(PowerWhitelistManager.REASON_VPN), eq("VpnManager event"));
+    }
+
     @Test
     public void testGetAlwaysAndOnGetLockDown() throws Exception {
         final Vpn vpn = createVpn(primaryUser.id);
@@ -1144,6 +1154,8 @@
         verifyPlatformVpnIsActivated(TEST_VPN_PKG);
         vpn.stopVpnProfile(TEST_VPN_PKG);
         verifyPlatformVpnIsDeactivated(TEST_VPN_PKG);
+        verifyPowerSaveTempWhitelistApp(TEST_VPN_PKG);
+        reset(mDeviceIdleInternal);
         // CATEGORY_EVENT_DEACTIVATED_BY_USER is not an error event, so both of errorClass and
         // errorCode won't be set.
         verifyVpnManagerEvent(sessionKey1, VpnManager.CATEGORY_EVENT_DEACTIVATED_BY_USER,
@@ -1155,6 +1167,8 @@
         verifyPlatformVpnIsActivated(TEST_VPN_PKG);
         vpn.prepare(TEST_VPN_PKG, "com.new.vpn" /* newPackage */, TYPE_VPN_PLATFORM);
         verifyPlatformVpnIsDeactivated(TEST_VPN_PKG);
+        verifyPowerSaveTempWhitelistApp(TEST_VPN_PKG);
+        reset(mDeviceIdleInternal);
         // CATEGORY_EVENT_DEACTIVATED_BY_USER is not an error event, so both of errorClass and
         // errorCode won't be set.
         verifyVpnManagerEvent(sessionKey2, VpnManager.CATEGORY_EVENT_DEACTIVATED_BY_USER,
@@ -1170,6 +1184,8 @@
         // Enable VPN always-on for PKGS[1].
         assertTrue(vpn.setAlwaysOnPackage(PKGS[1], false /* lockdown */,
                 null /* lockdownAllowlist */));
+        verifyPowerSaveTempWhitelistApp(PKGS[1]);
+        reset(mDeviceIdleInternal);
         verifyVpnManagerEvent(null /* sessionKey */,
                 VpnManager.CATEGORY_EVENT_ALWAYS_ON_STATE_CHANGED, -1 /* errorClass */,
                 -1 /* errorCode */, new VpnProfileState(VpnProfileState.STATE_DISCONNECTED,
@@ -1178,6 +1194,8 @@
         // Enable VPN lockdown for PKGS[1].
         assertTrue(vpn.setAlwaysOnPackage(PKGS[1], true /* lockdown */,
                 null /* lockdownAllowlist */));
+        verifyPowerSaveTempWhitelistApp(PKGS[1]);
+        reset(mDeviceIdleInternal);
         verifyVpnManagerEvent(null /* sessionKey */,
                 VpnManager.CATEGORY_EVENT_ALWAYS_ON_STATE_CHANGED, -1 /* errorClass */,
                 -1 /* errorCode */, new VpnProfileState(VpnProfileState.STATE_DISCONNECTED,
@@ -1186,6 +1204,8 @@
         // Disable VPN lockdown for PKGS[1].
         assertTrue(vpn.setAlwaysOnPackage(PKGS[1], false /* lockdown */,
                 null /* lockdownAllowlist */));
+        verifyPowerSaveTempWhitelistApp(PKGS[1]);
+        reset(mDeviceIdleInternal);
         verifyVpnManagerEvent(null /* sessionKey */,
                 VpnManager.CATEGORY_EVENT_ALWAYS_ON_STATE_CHANGED, -1 /* errorClass */,
                 -1 /* errorCode */, new VpnProfileState(VpnProfileState.STATE_DISCONNECTED,
@@ -1194,6 +1214,8 @@
         // Disable VPN always-on.
         assertTrue(vpn.setAlwaysOnPackage(null, false /* lockdown */,
                 null /* lockdownAllowlist */));
+        verifyPowerSaveTempWhitelistApp(PKGS[1]);
+        reset(mDeviceIdleInternal);
         verifyVpnManagerEvent(null /* sessionKey */,
                 VpnManager.CATEGORY_EVENT_ALWAYS_ON_STATE_CHANGED, -1 /* errorClass */,
                 -1 /* errorCode */, new VpnProfileState(VpnProfileState.STATE_DISCONNECTED,
@@ -1202,6 +1224,8 @@
         // Enable VPN always-on for PKGS[1] again.
         assertTrue(vpn.setAlwaysOnPackage(PKGS[1], false /* lockdown */,
                 null /* lockdownAllowlist */));
+        verifyPowerSaveTempWhitelistApp(PKGS[1]);
+        reset(mDeviceIdleInternal);
         verifyVpnManagerEvent(null /* sessionKey */,
                 VpnManager.CATEGORY_EVENT_ALWAYS_ON_STATE_CHANGED, -1 /* errorClass */,
                 -1 /* errorCode */, new VpnProfileState(VpnProfileState.STATE_DISCONNECTED,
@@ -1210,6 +1234,8 @@
         // Enable VPN always-on for PKGS[2].
         assertTrue(vpn.setAlwaysOnPackage(PKGS[2], false /* lockdown */,
                 null /* lockdownAllowlist */));
+        verifyPowerSaveTempWhitelistApp(PKGS[2]);
+        reset(mDeviceIdleInternal);
         // PKGS[1] is replaced with PKGS[2].
         // Pass 2 VpnProfileState objects to verifyVpnManagerEvent(), the first one is sent to
         // PKGS[1] to notify PKGS[1] that the VPN always-on is disabled, the second one is sent to
@@ -1310,6 +1336,8 @@
         final IkeSessionCallback ikeCb = captor.getValue();
         ikeCb.onClosedWithException(exception);
 
+        verifyPowerSaveTempWhitelistApp(TEST_VPN_PKG);
+        reset(mDeviceIdleInternal);
         verifyVpnManagerEvent(sessionKey, category, errorType, errorCode, null /* profileState */);
         if (errorType == VpnManager.ERROR_CLASS_NOT_RECOVERABLE) {
             verify(mConnectivityManager, timeout(TEST_TIMEOUT_MS))
@@ -1532,7 +1560,7 @@
         }
     }
 
-    private static final class TestDeps extends Vpn.Dependencies {
+    private final class TestDeps extends Vpn.Dependencies {
         public final CompletableFuture<String[]> racoonArgs = new CompletableFuture();
         public final CompletableFuture<String[]> mtpdArgs = new CompletableFuture();
         public final File mStateFile;
@@ -1661,6 +1689,11 @@
 
         @Override
         public void setBlocking(FileDescriptor fd, boolean blocking) {}
+
+        @Override
+        public DeviceIdleInternal getDeviceIdleInternal() {
+            return VpnTest.this.mDeviceIdleInternal; // @Mock field of the enclosing test
+        }
     }
 
     /**
diff --git a/tests/unit/java/com/android/server/ethernet/EthernetTrackerTest.java b/tests/unit/java/com/android/server/ethernet/EthernetTrackerTest.java
index 115f0e1..e90d55d 100644
--- a/tests/unit/java/com/android/server/ethernet/EthernetTrackerTest.java
+++ b/tests/unit/java/com/android/server/ethernet/EthernetTrackerTest.java
@@ -28,6 +28,7 @@
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
@@ -68,6 +69,7 @@
 
 import java.net.InetAddress;
 import java.util.ArrayList;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 @SmallTest
 @RunWith(AndroidJUnit4.class)
@@ -445,7 +447,20 @@
         when(mNetd.interfaceGetList()).thenReturn(new String[] {testIface});
         when(mNetd.interfaceGetCfg(eq(testIface))).thenReturn(ifaceParcel);
         doReturn(new String[] {testIface}).when(mFactory).getAvailableInterfaces(anyBoolean());
-        doReturn(EthernetManager.STATE_LINK_UP).when(mFactory).getInterfaceState(eq(testIface));
+
+        final AtomicBoolean ifaceUp = new AtomicBoolean(true);
+        doAnswer(inv -> ifaceUp.get()).when(mFactory).hasInterface(testIface);
+        doAnswer(inv ->
+                ifaceUp.get() ? EthernetManager.STATE_LINK_UP : EthernetManager.STATE_ABSENT)
+                .when(mFactory).getInterfaceState(testIface);
+        doAnswer(inv -> {
+            ifaceUp.set(true);
+            return null;
+        }).when(mFactory).addInterface(eq(testIface), eq(testHwAddr), any(), any());
+        doAnswer(inv -> {
+            ifaceUp.set(false);
+            return null;
+        }).when(mFactory).removeInterface(testIface);
 
         final EthernetStateListener listener = spy(new EthernetStateListener());
         tracker.addListener(listener, true /* canUseRestrictedNetworks */);
@@ -456,7 +471,6 @@
         verify(listener).onEthernetStateChanged(eq(EthernetManager.ETHERNET_STATE_ENABLED));
         reset(listener);
 
-        doReturn(EthernetManager.STATE_ABSENT).when(mFactory).getInterfaceState(eq(testIface));
         tracker.setEthernetEnabled(false);
         waitForIdle();
         verify(mFactory).removeInterface(eq(testIface));
@@ -465,7 +479,6 @@
                 anyInt(), any());
         reset(listener);
 
-        doReturn(EthernetManager.STATE_LINK_UP).when(mFactory).getInterfaceState(eq(testIface));
         tracker.setEthernetEnabled(true);
         waitForIdle();
         verify(mFactory).addInterface(eq(testIface), eq(testHwAddr), any(), any());