Merge changes I968bfa76,Id46f1e5b,Iff9b6212,I6bdb090a

* changes:
  Use TestConnectivityManager in TetheringTest.
  Support building different UpstreamNetworkState test objects.
  Change TetheringTest's UpstreamNetworkMonitor from mock to spy.
  Make TestConnectivityManager usable by other tethering tests.
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java
new file mode 100644
index 0000000..3a6350c
--- /dev/null
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java
@@ -0,0 +1,247 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.networkstack.tethering;
+
+import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.fail;
+
+import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.IConnectivityManager;
+import android.net.LinkProperties;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
+import android.os.Handler;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Simulates upstream switching and sending NetworkCallbacks and CONNECTIVITY_ACTION broadcasts.
+ *
+ * TODO: this duplicates a fair amount of code from ConnectivityManager and ConnectivityService.
+ * Consider using a ConnectivityService object instead, as used in ConnectivityServiceTest.
+ *
+ * Things to consider:
+ * - ConnectivityService uses a real handler for realism, and these tests use TestLooper (or even
+ *   invoke callbacks directly inline) for determinism. Using a real ConnectivityService would
+ *   require adding dispatchAll() calls and migrating to handlers.
+ * - ConnectivityService does not provide a way to order CONNECTIVITY_ACTION before or after the
+ *   NetworkCallbacks for the same network change. That ability is useful because the upstream
+ *   selection code in Tethering is vulnerable to race conditions, due to its reliance on multiple
+ *   separate NetworkCallbacks and BroadcastReceivers, each of which triggers different types of
+ *   updates. If/when the upstream selection code is refactored to a more level-triggered model
+ *   (e.g., with an idempotent function that takes into account all state every time any part of
+ *   that state changes), this may become less important or unnecessary.
+ */
+public class TestConnectivityManager extends ConnectivityManager {
+    public Map<NetworkCallback, Handler> allCallbacks = new HashMap<>();
+    public Set<NetworkCallback> trackingDefault = new HashSet<>();
+    public TestNetworkAgent defaultNetwork = null;
+    public Map<NetworkCallback, NetworkRequest> listening = new HashMap<>();
+    public Map<NetworkCallback, NetworkRequest> requested = new HashMap<>();
+    public Map<NetworkCallback, Integer> legacyTypeMap = new HashMap<>();
+
+    private final NetworkRequest mDefaultRequest;
+    private int mNetworkId = 100;
+
+    public TestConnectivityManager(Context ctx, IConnectivityManager svc,
+            NetworkRequest defaultRequest) {
+        super(ctx, svc);
+        mDefaultRequest = defaultRequest;  // requestNetwork() compares requests against this
+    }
+
+    boolean hasNoCallbacks() {
+        // True only when every bookkeeping collection is empty.
+        final boolean noDefaultTrackers = trackingDefault.isEmpty();
+        final boolean noListeners = listening.isEmpty();
+        final boolean noRequests = requested.isEmpty() && legacyTypeMap.isEmpty();
+        return allCallbacks.isEmpty() && noDefaultTrackers && noListeners && noRequests;
+    }
+
+    boolean onlyHasDefaultCallbacks() {
+        // Exactly one callback is registered and exactly one tracks the default network.
+        final boolean oneDefaultTracker = allCallbacks.size() == 1 && trackingDefault.size() == 1;
+        final boolean nothingElse =
+                listening.isEmpty() && requested.isEmpty() && legacyTypeMap.isEmpty();
+        return oneDefaultTracker && nothingElse;
+    }
+
+    boolean isListeningForAll() {
+        // Reports whether any listen request has capabilities that equal, according to
+        // equalRequestableCapabilities(), a NetworkCapabilities object on which clearAll()
+        // has been called.
+        final NetworkCapabilities cleared = new NetworkCapabilities();
+        cleared.clearAll();
+
+        return listening.values().stream()
+                .map(listenRequest -> listenRequest.networkCapabilities)
+                .anyMatch(requestCaps -> requestCaps.equalRequestableCapabilities(cleared));
+    }
+
+    int getNetworkId() {
+        return ++mNetworkId;  // pre-increment: every call hands out the next, unused ID
+    }
+
+    /**
+     * Makes {@code agent} the default network and notifies every default-tracking callback.
+     * Does nothing if {@code agent} already is the default; sends nothing if it is null.
+     */
+    void makeDefaultNetwork(TestNetworkAgent agent) {
+        if (Objects.equals(defaultNetwork, agent)) return;
+
+        defaultNetwork = agent;
+        if (agent == null) return;
+
+        for (NetworkCallback cb : trackingDefault) {
+            cb.onAvailable(agent.networkId);
+            cb.onCapabilitiesChanged(agent.networkId, agent.networkCapabilities);
+            cb.onLinkPropertiesChanged(agent.networkId, agent.linkProperties);
+        }
+    }
+
+    @Override
+    public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
+        assertFalse(allCallbacks.containsKey(cb));
+        allCallbacks.put(cb, h);
+        if (!mDefaultRequest.equals(req)) {  // not the default request: just record it
+            assertFalse(requested.containsKey(cb));
+            requested.put(cb, req);
+            return;
+        }
+        assertFalse(trackingDefault.contains(cb));
+        trackingDefault.add(cb);
+    }
+
+    @Override
+    public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
+        fail("Should never be called.");  // only the overloads taking a Handler are supported
+    }
+
+    @Override
+    public void requestNetwork(NetworkRequest req,
+            int timeoutMs, int legacyType, Handler h, NetworkCallback cb) {
+        // timeoutMs is not used by this fake.
+        assertFalse(allCallbacks.containsKey(cb));
+        allCallbacks.put(cb, h);
+        assertFalse(requested.containsKey(cb));
+        requested.put(cb, req);
+        assertFalse(legacyTypeMap.containsKey(cb));
+        if (legacyType == ConnectivityManager.TYPE_NONE) return;  // nothing more to record
+        legacyTypeMap.put(cb, legacyType);
+    }
+
+    @Override
+    public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
+        assertFalse(allCallbacks.containsKey(cb));  // a callback may be registered only once
+        allCallbacks.put(cb, h);
+        assertFalse(listening.containsKey(cb));
+        listening.put(cb, req);  // TestNetworkAgent delivers its events to these callbacks
+    }
+
+    @Override
+    public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
+        fail("Should never be called.");  // only the overload taking a Handler is supported
+    }
+
+    @Override
+    public void registerDefaultNetworkCallback(NetworkCallback cb, Handler h) {
+        fail("Should never be called.");  // default tracking goes through requestNetwork()
+    }
+
+    @Override
+    public void registerDefaultNetworkCallback(NetworkCallback cb) {
+        fail("Should never be called.");  // default tracking goes through requestNetwork()
+    }
+
+    @Override
+    public void unregisterNetworkCallback(NetworkCallback cb) {
+        final boolean wasTrackingDefault = trackingDefault.remove(cb);
+        if (!wasTrackingDefault) {  // otherwise it must be a listen or an ordinary request
+            if (listening.containsKey(cb)) {
+                listening.remove(cb);
+            } else if (requested.containsKey(cb)) {
+                requested.remove(cb);
+                legacyTypeMap.remove(cb);
+            } else {
+                fail("Unexpected callback removed");
+            }
+        }
+        allCallbacks.remove(cb);
+        assertFalse(allCallbacks.containsKey(cb));  // no bookkeeping for cb may remain
+        assertFalse(trackingDefault.contains(cb));
+        assertFalse(listening.containsKey(cb));
+        assertFalse(requested.containsKey(cb));
+    }
+
+    /** Fake network whose events go straight to the manager's listening callbacks. */
+    public static class TestNetworkAgent {
+        public final TestConnectivityManager cm;
+        public final Network networkId;
+        public final int transportType;
+        public final NetworkCapabilities networkCapabilities;
+        public final LinkProperties linkProperties;
+
+        public TestNetworkAgent(TestConnectivityManager cm, int transportType) {
+            this.cm = cm;
+            this.networkId = new Network(cm.getNetworkId());
+            this.transportType = transportType;
+            this.networkCapabilities = new NetworkCapabilities()
+                    .addTransportType(transportType).addCapability(NET_CAPABILITY_INTERNET);
+            this.linkProperties = new LinkProperties();
+        }
+
+        public void fakeConnect() {  // callbacks receive copies of the agent's state
+            for (NetworkCallback listener : cm.listening.keySet()) {
+                listener.onAvailable(networkId);
+                listener.onCapabilitiesChanged(networkId, copy(networkCapabilities));
+                listener.onLinkPropertiesChanged(networkId, copy(linkProperties));
+            }
+        }
+
+        public void fakeDisconnect() {
+            for (NetworkCallback listener : cm.listening.keySet()) {
+                listener.onLost(networkId);
+            }
+        }
+
+        public void sendLinkProperties() {
+            for (NetworkCallback listener : cm.listening.keySet()) {
+                listener.onLinkPropertiesChanged(networkId, copy(linkProperties));
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "TestNetworkAgent: " + networkId + " " + networkCapabilities;
+        }
+    }
+
+    static NetworkCapabilities copy(NetworkCapabilities nc) {
+        return new NetworkCapabilities(nc);  // copy constructor: callers get a separate object
+    }
+
+    static LinkProperties copy(LinkProperties lp) {
+        return new LinkProperties(lp);  // copy constructor: callers get a separate object
+    }
+}
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index c5a43b7..40b7c55 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -25,6 +25,9 @@
 import static android.net.ConnectivityManager.ACTION_RESTRICT_BACKGROUND_CHANGED;
 import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_DISABLED;
 import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_ENABLED;
+import static android.net.ConnectivityManager.TYPE_NONE;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
 import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
@@ -64,6 +67,7 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.argThat;
@@ -72,6 +76,7 @@
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -94,10 +99,11 @@
 import android.content.pm.PackageManager;
 import android.content.res.Resources;
 import android.hardware.usb.UsbManager;
-import android.net.ConnectivityManager;
+import android.net.ConnectivityManager.NetworkCallback;
 import android.net.EthernetManager;
 import android.net.EthernetManager.TetheredInterfaceCallback;
 import android.net.EthernetManager.TetheredInterfaceRequest;
+import android.net.IConnectivityManager;
 import android.net.IIntResultListener;
 import android.net.INetd;
 import android.net.ITetheringEventCallback;
@@ -200,6 +206,10 @@
     private static final String[] PROVISIONING_APP_NAME = {"some", "app"};
     private static final String PROVISIONING_NO_UI_APP_NAME = "no_ui_app";
 
+    private static final int CELLULAR_NETID = 100;
+    private static final int WIFI_NETID = 101;
+    private static final int DUN_NETID = 102;
+
     private static final int DHCPSERVER_START_TIMEOUT_MS = 1000;
 
     @Mock private ApplicationInfo mApplicationInfo;
@@ -212,7 +222,6 @@
     @Mock private UsbManager mUsbManager;
     @Mock private WifiManager mWifiManager;
     @Mock private CarrierConfigManager mCarrierConfigManager;
-    @Mock private UpstreamNetworkMonitor mUpstreamNetworkMonitor;
     @Mock private IPv6TetheringCoordinator mIPv6TetheringCoordinator;
     @Mock private DadProxy mDadProxy;
     @Mock private RouterAdvertisementDaemon mRouterAdvertisementDaemon;
@@ -220,8 +229,6 @@
     @Mock private IDhcpServer mDhcpServer;
     @Mock private INetd mNetd;
     @Mock private UserManager mUserManager;
-    @Mock private NetworkRequest mNetworkRequest;
-    @Mock private ConnectivityManager mCm;
     @Mock private EthernetManager mEm;
     @Mock private TetheringNotificationUpdater mNotificationUpdater;
     @Mock private BpfCoordinator mBpfCoordinator;
@@ -249,6 +256,11 @@
     private OffloadController mOffloadCtrl;
     private PrivateAddressCoordinator mPrivateAddressCoordinator;
     private SoftApCallback mSoftApCallback;
+    private UpstreamNetworkMonitor mUpstreamNetworkMonitor;
+
+    private TestConnectivityManager mCm;
+    private NetworkRequest mNetworkRequest;
+    private NetworkCallback mDefaultNetworkCallback;
 
     private class TestContext extends BroadcastInterceptingContext {
         TestContext(Context base) {
@@ -400,7 +412,10 @@
         @Override
         public UpstreamNetworkMonitor getUpstreamNetworkMonitor(Context ctx,
                 StateMachine target, SharedLog log, int what) {
+            // Use a real object instead of a mock so that some tests can use a real UNM and some
+            // can use a mock.
             mUpstreamNetworkMonitorSM = target;
+            mUpstreamNetworkMonitor = spy(super.getUpstreamNetworkMonitor(ctx, target, log, what));
             return mUpstreamNetworkMonitor;
         }
 
@@ -480,15 +495,15 @@
         }
     }
 
-    private static UpstreamNetworkState buildMobileUpstreamState(boolean withIPv4,
-            boolean withIPv6, boolean with464xlat) {
+    private static LinkProperties buildUpstreamLinkProperties(String interfaceName,
+            boolean withIPv4, boolean withIPv6, boolean with464xlat) {
         final LinkProperties prop = new LinkProperties();
-        prop.setInterfaceName(TEST_MOBILE_IFNAME);
+        prop.setInterfaceName(interfaceName);
 
         if (withIPv4) {
             prop.addRoute(new RouteInfo(new IpPrefix(Inet4Address.ANY, 0),
                     InetAddresses.parseNumericAddress("10.0.0.1"),
-                    TEST_MOBILE_IFNAME, RTN_UNICAST));
+                    interfaceName, RTN_UNICAST));
         }
 
         if (withIPv6) {
@@ -498,23 +513,40 @@
                             NetworkConstants.RFC7421_PREFIX_LENGTH));
             prop.addRoute(new RouteInfo(new IpPrefix(Inet6Address.ANY, 0),
                     InetAddresses.parseNumericAddress("2001:db8::1"),
-                    TEST_MOBILE_IFNAME, RTN_UNICAST));
+                    interfaceName, RTN_UNICAST));
         }
 
         if (with464xlat) {
+            final String clatInterface = "v4-" + interfaceName;
             final LinkProperties stackedLink = new LinkProperties();
-            stackedLink.setInterfaceName(TEST_XLAT_MOBILE_IFNAME);
+            stackedLink.setInterfaceName(clatInterface);
             stackedLink.addRoute(new RouteInfo(new IpPrefix(Inet4Address.ANY, 0),
                     InetAddresses.parseNumericAddress("192.0.0.1"),
-                    TEST_XLAT_MOBILE_IFNAME, RTN_UNICAST));
+                    clatInterface, RTN_UNICAST));
 
             prop.addStackedLink(stackedLink);
         }
 
+        return prop;
+    }
 
-        final NetworkCapabilities capabilities = new NetworkCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR);
-        return new UpstreamNetworkState(prop, capabilities, new Network(100));
+    private static NetworkCapabilities buildUpstreamCapabilities(int transport, int... otherCaps) {
+        // TODO: add NOT_VCN_MANAGED.
+        final NetworkCapabilities caps = new NetworkCapabilities();
+        caps.addTransportType(transport);
+        caps.addCapability(NET_CAPABILITY_INTERNET);  // every test upstream has INTERNET
+        for (int extraCap : otherCaps) {
+            caps.addCapability(extraCap);
+        }
+        return caps;
+    }
+
+    private static UpstreamNetworkState buildMobileUpstreamState(boolean withIPv4,
+            boolean withIPv6, boolean with464xlat) {
+        final LinkProperties lp = buildUpstreamLinkProperties(
+                TEST_MOBILE_IFNAME, withIPv4, withIPv6, with464xlat);
+        final NetworkCapabilities nc = buildUpstreamCapabilities(TRANSPORT_CELLULAR);
+        return new UpstreamNetworkState(lp, nc, new Network(CELLULAR_NETID));
     }
 
     private static UpstreamNetworkState buildMobileIPv4UpstreamState() {
@@ -533,6 +565,22 @@
         return buildMobileUpstreamState(false, true, true);
     }
 
+    private static UpstreamNetworkState buildWifiUpstreamState() {
+        // Wi-Fi upstream with both IPv4 and IPv6 and no 464xlat.
+        final LinkProperties lp = buildUpstreamLinkProperties(
+                TEST_WIFI_IFNAME, true /* IPv4 */, true /* IPv6 */, false /* 464xlat */);
+        final NetworkCapabilities nc = buildUpstreamCapabilities(TRANSPORT_WIFI);
+        return new UpstreamNetworkState(lp, nc, new Network(WIFI_NETID));
+    }
+
+    private static UpstreamNetworkState buildDunUpstreamState() {  // cellular DUN upstream
+        final LinkProperties lp = buildUpstreamLinkProperties(
+                TEST_MOBILE_IFNAME, true /* IPv4 */, true /* IPv6 */, false /* 464xlat */);
+        final NetworkCapabilities nc =
+                buildUpstreamCapabilities(TRANSPORT_CELLULAR, NET_CAPABILITY_DUN);
+        return new UpstreamNetworkState(lp, nc, new Network(DUN_NETID));
+    }
+
     // See FakeSettingsProvider#clearSettingsProvider() that this needs to be called before and
     // after use.
     @BeforeClass
@@ -578,9 +626,22 @@
         };
         mServiceContext.registerReceiver(mBroadcastReceiver,
                 new IntentFilter(ACTION_TETHER_STATE_CHANGED));
+
+        // TODO: add NOT_VCN_MANAGED here, but more importantly in the production code.
+        // TODO: even better, change TetheringDependencies.getDefaultNetworkRequest() to use
+        // registerSystemDefaultNetworkCallback() on S and above.
+        NetworkCapabilities defaultCaps = new NetworkCapabilities()
+                .addCapability(NET_CAPABILITY_INTERNET);
+        mNetworkRequest = new NetworkRequest(defaultCaps, TYPE_NONE, 1 /* requestId */,
+                NetworkRequest.Type.REQUEST);
+        mCm = spy(new TestConnectivityManager(mServiceContext, mock(IConnectivityManager.class),
+                mNetworkRequest));
+
         mTethering = makeTethering();
         verify(mStatsManager, times(1)).registerNetworkStatsProvider(anyString(), any());
         verify(mNetd).registerUnsolicitedEventListener(any());
+        verifyDefaultNetworkRequestFiled();
+
         final ArgumentCaptor<PhoneStateListener> phoneListenerCaptor =
                 ArgumentCaptor.forClass(PhoneStateListener.class);
         verify(mTelephonyManager).listen(phoneListenerCaptor.capture(),
@@ -617,8 +678,8 @@
     }
 
     private void initTetheringUpstream(UpstreamNetworkState upstreamState) {
-        when(mUpstreamNetworkMonitor.getCurrentPreferredUpstream()).thenReturn(upstreamState);
-        when(mUpstreamNetworkMonitor.selectPreferredUpstreamType(any())).thenReturn(upstreamState);
+        doReturn(upstreamState).when(mUpstreamNetworkMonitor).getCurrentPreferredUpstream();
+        doReturn(upstreamState).when(mUpstreamNetworkMonitor).selectPreferredUpstreamType(any());
     }
 
     private Tethering makeTethering() {
@@ -705,6 +766,19 @@
         mServiceContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
     }
 
+    private void verifyDefaultNetworkRequestFiled() {  // also captures the filed callback
+        ArgumentCaptor<NetworkCallback> captor = ArgumentCaptor.forClass(NetworkCallback.class);
+        verify(mCm, times(1)).requestNetwork(eq(mNetworkRequest),
+                captor.capture(), any(Handler.class));
+        mDefaultNetworkCallback = captor.getValue();
+        assertNotNull(mDefaultNetworkCallback);
+
+        // The default network request is only ever filed once: tracking again must not touch mCm.
+        verifyNoMoreInteractions(mCm);
+        mUpstreamNetworkMonitor.startTrackDefaultNetwork(mNetworkRequest, mEntitleMgr);
+        verifyNoMoreInteractions(mCm);
+    }
+
     private void verifyInterfaceServingModeStarted(String ifname) throws Exception {
         verify(mNetd, times(1)).interfaceSetCfg(any(InterfaceConfigurationParcel.class));
         verify(mNetd, times(1)).tetherInterfaceAdd(ifname);
@@ -1723,12 +1797,12 @@
     }
 
     private void setDataSaverEnabled(boolean enabled) {
-        final Intent intent = new Intent(ACTION_RESTRICT_BACKGROUND_CHANGED);
-        mServiceContext.sendBroadcastAsUser(intent, UserHandle.ALL);
-
         final int status = enabled ? RESTRICT_BACKGROUND_STATUS_ENABLED
                 : RESTRICT_BACKGROUND_STATUS_DISABLED;
-        when(mCm.getRestrictBackgroundStatus()).thenReturn(status);
+        doReturn(status).when(mCm).getRestrictBackgroundStatus();
+
+        final Intent intent = new Intent(ACTION_RESTRICT_BACKGROUND_CHANGED);
+        mServiceContext.sendBroadcastAsUser(intent, UserHandle.ALL);
         mLooper.dispatchAll();
     }
 
@@ -1882,7 +1956,8 @@
         // Verify that onUpstreamCapabilitiesChanged won't be called if not current upstream network
         // capabilities changed.
         final UpstreamNetworkState upstreamState2 = new UpstreamNetworkState(
-                upstreamState.linkProperties, upstreamState.networkCapabilities, new Network(101));
+                upstreamState.linkProperties, upstreamState.networkCapabilities,
+                new Network(WIFI_NETID));
         stateMachine.handleUpstreamNetworkMonitorCallback(EVENT_ON_CAPABILITIES, upstreamState2);
         verify(mNotificationUpdater, never()).onUpstreamCapabilitiesChanged(any());
     }
@@ -1992,7 +2067,7 @@
     public void testHandleIpConflict() throws Exception {
         final Network wifiNetwork = new Network(200);
         final Network[] allNetworks = { wifiNetwork };
-        when(mCm.getAllNetworks()).thenReturn(allNetworks);
+        doReturn(allNetworks).when(mCm).getAllNetworks();
         runUsbTethering(null);
         final ArgumentCaptor<InterfaceConfigurationParcel> ifaceConfigCaptor =
                 ArgumentCaptor.forClass(InterfaceConfigurationParcel.class);
@@ -2019,7 +2094,7 @@
         final Network btNetwork = new Network(201);
         final Network mobileNetwork = new Network(202);
         final Network[] allNetworks = { wifiNetwork, btNetwork, mobileNetwork };
-        when(mCm.getAllNetworks()).thenReturn(allNetworks);
+        doReturn(allNetworks).when(mCm).getAllNetworks();
         runUsbTethering(null);
         verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS).times(1)).startWithCallbacks(
                 any(), any());
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java
index 232588c..e358f5a 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java
@@ -29,7 +29,6 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.anyString;
@@ -48,7 +47,6 @@
 import android.net.IpPrefix;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
-import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.util.SharedLog;
@@ -60,6 +58,7 @@
 
 import com.android.internal.util.State;
 import com.android.internal.util.StateMachine;
+import com.android.networkstack.tethering.TestConnectivityManager.TestNetworkAgent;
 
 import org.junit.After;
 import org.junit.Before;
@@ -71,10 +70,7 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 
 @RunWith(AndroidJUnit4.class)
@@ -106,7 +102,7 @@
         when(mLog.forSubComponent(anyString())).thenReturn(mLog);
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(true);
 
-        mCM = spy(new TestConnectivityManager(mContext, mCS));
+        mCM = spy(new TestConnectivityManager(mContext, mCS, sDefaultRequest));
         mSM = new TestStateMachine();
         mUNM = new UpstreamNetworkMonitor(
                 (ConnectivityManager) mCM, mSM, mLog, EVENT_UNM_UPDATE);
@@ -567,187 +563,6 @@
         return false;
     }
 
-    public static class TestConnectivityManager extends ConnectivityManager {
-        public Map<NetworkCallback, Handler> allCallbacks = new HashMap<>();
-        public Set<NetworkCallback> trackingDefault = new HashSet<>();
-        public TestNetworkAgent defaultNetwork = null;
-        public Map<NetworkCallback, NetworkRequest> listening = new HashMap<>();
-        public Map<NetworkCallback, NetworkRequest> requested = new HashMap<>();
-        public Map<NetworkCallback, Integer> legacyTypeMap = new HashMap<>();
-
-        private int mNetworkId = 100;
-
-        public TestConnectivityManager(Context ctx, IConnectivityManager svc) {
-            super(ctx, svc);
-        }
-
-        boolean hasNoCallbacks() {
-            return allCallbacks.isEmpty()
-                    && trackingDefault.isEmpty()
-                    && listening.isEmpty()
-                    && requested.isEmpty()
-                    && legacyTypeMap.isEmpty();
-        }
-
-        boolean onlyHasDefaultCallbacks() {
-            return (allCallbacks.size() == 1)
-                    && (trackingDefault.size() == 1)
-                    && listening.isEmpty()
-                    && requested.isEmpty()
-                    && legacyTypeMap.isEmpty();
-        }
-
-        boolean isListeningForAll() {
-            final NetworkCapabilities empty = new NetworkCapabilities();
-            empty.clearAll();
-
-            for (NetworkRequest req : listening.values()) {
-                if (req.networkCapabilities.equalRequestableCapabilities(empty)) {
-                    return true;
-                }
-            }
-            return false;
-        }
-
-        int getNetworkId() {
-            return ++mNetworkId;
-        }
-
-        void makeDefaultNetwork(TestNetworkAgent agent) {
-            if (Objects.equals(defaultNetwork, agent)) return;
-
-            final TestNetworkAgent formerDefault = defaultNetwork;
-            defaultNetwork = agent;
-
-            for (NetworkCallback cb : trackingDefault) {
-                if (defaultNetwork != null) {
-                    cb.onAvailable(defaultNetwork.networkId);
-                    cb.onCapabilitiesChanged(
-                            defaultNetwork.networkId, defaultNetwork.networkCapabilities);
-                    cb.onLinkPropertiesChanged(
-                            defaultNetwork.networkId, defaultNetwork.linkProperties);
-                }
-            }
-        }
-
-        @Override
-        public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
-            assertFalse(allCallbacks.containsKey(cb));
-            allCallbacks.put(cb, h);
-            if (sDefaultRequest.equals(req)) {
-                assertFalse(trackingDefault.contains(cb));
-                trackingDefault.add(cb);
-            } else {
-                assertFalse(requested.containsKey(cb));
-                requested.put(cb, req);
-            }
-        }
-
-        @Override
-        public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
-            fail("Should never be called.");
-        }
-
-        @Override
-        public void requestNetwork(NetworkRequest req,
-                int timeoutMs, int legacyType, Handler h, NetworkCallback cb) {
-            assertFalse(allCallbacks.containsKey(cb));
-            allCallbacks.put(cb, h);
-            assertFalse(requested.containsKey(cb));
-            requested.put(cb, req);
-            assertFalse(legacyTypeMap.containsKey(cb));
-            if (legacyType != ConnectivityManager.TYPE_NONE) {
-                legacyTypeMap.put(cb, legacyType);
-            }
-        }
-
-        @Override
-        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
-            assertFalse(allCallbacks.containsKey(cb));
-            allCallbacks.put(cb, h);
-            assertFalse(listening.containsKey(cb));
-            listening.put(cb, req);
-        }
-
-        @Override
-        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
-            fail("Should never be called.");
-        }
-
-        @Override
-        public void registerDefaultNetworkCallback(NetworkCallback cb, Handler h) {
-            fail("Should never be called.");
-        }
-
-        @Override
-        public void registerDefaultNetworkCallback(NetworkCallback cb) {
-            fail("Should never be called.");
-        }
-
-        @Override
-        public void unregisterNetworkCallback(NetworkCallback cb) {
-            if (trackingDefault.contains(cb)) {
-                trackingDefault.remove(cb);
-            } else if (listening.containsKey(cb)) {
-                listening.remove(cb);
-            } else if (requested.containsKey(cb)) {
-                requested.remove(cb);
-                legacyTypeMap.remove(cb);
-            } else {
-                fail("Unexpected callback removed");
-            }
-            allCallbacks.remove(cb);
-
-            assertFalse(allCallbacks.containsKey(cb));
-            assertFalse(trackingDefault.contains(cb));
-            assertFalse(listening.containsKey(cb));
-            assertFalse(requested.containsKey(cb));
-        }
-    }
-
-    public static class TestNetworkAgent {
-        public final TestConnectivityManager cm;
-        public final Network networkId;
-        public final int transportType;
-        public final NetworkCapabilities networkCapabilities;
-        public final LinkProperties linkProperties;
-
-        public TestNetworkAgent(TestConnectivityManager cm, int transportType) {
-            this.cm = cm;
-            this.networkId = new Network(cm.getNetworkId());
-            this.transportType = transportType;
-            networkCapabilities = new NetworkCapabilities();
-            networkCapabilities.addTransportType(transportType);
-            networkCapabilities.addCapability(NET_CAPABILITY_INTERNET);
-            linkProperties = new LinkProperties();
-        }
-
-        public void fakeConnect() {
-            for (NetworkCallback cb : cm.listening.keySet()) {
-                cb.onAvailable(networkId);
-                cb.onCapabilitiesChanged(networkId, copy(networkCapabilities));
-                cb.onLinkPropertiesChanged(networkId, copy(linkProperties));
-            }
-        }
-
-        public void fakeDisconnect() {
-            for (NetworkCallback cb : cm.listening.keySet()) {
-                cb.onLost(networkId);
-            }
-        }
-
-        public void sendLinkProperties() {
-            for (NetworkCallback cb : cm.listening.keySet()) {
-                cb.onLinkPropertiesChanged(networkId, copy(linkProperties));
-            }
-        }
-
-        @Override
-        public String toString() {
-            return String.format("TestNetworkAgent: %s %s", networkId, networkCapabilities);
-        }
-    }
-
     public static class TestStateMachine extends StateMachine {
         public final ArrayList<Message> messages = new ArrayList<>();
         private final State mLoggingState = new LoggingState();
@@ -775,14 +590,6 @@
         }
     }
 
-    static NetworkCapabilities copy(NetworkCapabilities nc) {
-        return new NetworkCapabilities(nc);
-    }
-
-    static LinkProperties copy(LinkProperties lp) {
-        return new LinkProperties(lp);
-    }
-
     static void assertPrefixSet(Set<IpPrefix> prefixes, boolean expectation, String... expected) {
         final Set<String> expectedSet = new HashSet<>();
         Collections.addAll(expectedSet, expected);
@@ -797,4 +604,4 @@
                     expectation, prefixes.contains(new IpPrefix(expectedPrefix)));
         }
     }
-}
+}
\ No newline at end of file