Merge "Revert "Add service-connectivity to tethering APEX""
diff --git a/tests/cts/net/TEST_MAPPING b/TEST_MAPPING
similarity index 88%
rename from tests/cts/net/TEST_MAPPING
rename to TEST_MAPPING
index 8f65b65..1db4baa 100644
--- a/tests/cts/net/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -1,5 +1,6 @@
 {
-  // TODO: move to mainline-presubmit once supported
+  // Run in addition to mainline-presubmit as mainline-presubmit is not
+  // supported in every branch.
   "presubmit": [
     {
       "name": "CtsNetTestCasesLatestSdk",
diff --git a/Tethering/Android.bp b/Tethering/Android.bp
index 23aa7f8..d8557ad 100644
--- a/Tethering/Android.bp
+++ b/Tethering/Android.bp
@@ -67,6 +67,9 @@
         "liblog",
         "libnativehelper_compat_libc++",
     ],
+    static_libs: [
+        "libnetjniutils",
+    ],
 
     // We cannot use plain "libc++" here to link libc++ dynamically because it results in:
     //   java.lang.UnsatisfiedLinkError: dlopen failed: library "libc++_shared.so" not found
diff --git a/Tethering/jni/android_net_util_TetheringUtils.cpp b/Tethering/jni/android_net_util_TetheringUtils.cpp
index 94c871d..7bfb6da 100644
--- a/Tethering/jni/android_net_util_TetheringUtils.cpp
+++ b/Tethering/jni/android_net_util_TetheringUtils.cpp
@@ -19,8 +19,8 @@
 #include <jni.h>
 #include <linux/filter.h>
 #include <nativehelper/JNIHelp.h>
-#include <nativehelper/JNIHelpCompat.h>
 #include <nativehelper/ScopedUtfChars.h>
+#include <netjniutils/netjniutils.h>
 #include <net/if.h>
 #include <netinet/ether.h>
 #include <netinet/ip6.h>
@@ -57,7 +57,7 @@
         filter_code,
     };
 
-    int fd = jniGetFDFromFileDescriptor(env, javaFd);
+    int fd = netjniutils::GetNativeFileDescriptor(env, javaFd);
     if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_FILTER, &filter, sizeof(filter)) != 0) {
         jniThrowExceptionFmt(env, "java/net/SocketException",
                 "setsockopt(SO_ATTACH_FILTER): %s", strerror(errno));
@@ -79,7 +79,7 @@
 {
     static const int kLinkLocalHopLimit = 255;
 
-    int fd = jniGetFDFromFileDescriptor(env, javaFd);
+    int fd = netjniutils::GetNativeFileDescriptor(env, javaFd);
 
     // Set an ICMPv6 filter that only passes Router Solicitations.
     struct icmp6_filter rs_only;
diff --git a/Tethering/src/android/net/ip/NeighborPacketForwarder.java b/Tethering/src/android/net/ip/NeighborPacketForwarder.java
index 73fc833..084743d 100644
--- a/Tethering/src/android/net/ip/NeighborPacketForwarder.java
+++ b/Tethering/src/android/net/ip/NeighborPacketForwarder.java
@@ -25,7 +25,6 @@
 import static android.system.OsConstants.SOCK_RAW;
 
 import android.net.util.InterfaceParams;
-import android.net.util.PacketReader;
 import android.net.util.SocketUtils;
 import android.net.util.TetheringUtils;
 import android.os.Handler;
@@ -33,6 +32,8 @@
 import android.system.Os;
 import android.util.Log;
 
+import com.android.net.module.util.PacketReader;
+
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.net.Inet6Address;
diff --git a/Tethering/src/com/android/networkstack/tethering/Tethering.java b/Tethering/src/com/android/networkstack/tethering/Tethering.java
index 2c91d10..62ae88c 100644
--- a/Tethering/src/com/android/networkstack/tethering/Tethering.java
+++ b/Tethering/src/com/android/networkstack/tethering/Tethering.java
@@ -140,8 +140,8 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.Iterator;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.Executor;
@@ -216,13 +216,12 @@
     private final ArrayMap<String, TetherState> mTetherStates;
     private final BroadcastReceiver mStateReceiver;
     private final Looper mLooper;
-    private final StateMachine mTetherMainSM;
+    private final TetherMainSM mTetherMainSM;
     private final OffloadController mOffloadController;
     private final UpstreamNetworkMonitor mUpstreamNetworkMonitor;
     // TODO: Figure out how to merge this and other downstream-tracking objects
     // into a single coherent structure.
-    // Use LinkedHashSet for predictable ordering order for ConnectedClientsTracker.
-    private final LinkedHashSet<IpServer> mForwardedDownstreams;
+    private final HashSet<IpServer> mForwardedDownstreams;
     private final VersionedBroadcastListener mCarrierConfigChange;
     private final TetheringDependencies mDeps;
     private final EntitlementManager mEntitlementMgr;
@@ -287,7 +286,7 @@
                 });
         mUpstreamNetworkMonitor = mDeps.getUpstreamNetworkMonitor(mContext, mTetherMainSM, mLog,
                 TetherMainSM.EVENT_UPSTREAM_CALLBACK);
-        mForwardedDownstreams = new LinkedHashSet<>();
+        mForwardedDownstreams = new HashSet<>();
 
         IntentFilter filter = new IntentFilter();
         filter.addAction(ACTION_CARRIER_CONFIG_CHANGED);
@@ -1423,6 +1422,7 @@
         //    interfaces.
         // 2) mNotifyList contains all state machines that may have outstanding tethering state
         //    that needs to be torn down.
+        // 3) mNotifyList provides a predictable ordering for ConnectedClientsTracker.
         //
         // Because we excise interfaces immediately from mTetherStates, we must maintain mNotifyList
         // so that the garbage collector does not clean up the state machine before it has a chance
@@ -1459,6 +1459,15 @@
             setInitialState(mInitialState);
         }
 
+        /**
+         * Returns a read-only view of all downstreams that are serving clients, whether they
+         * are tethered or local-only. Must be called on the tethering thread (not thread-safe).
+         */
+        @NonNull
+        public List<IpServer> getAllDownstreams() {
+            return Collections.unmodifiableList(mNotifyList);
+        }
+
         class InitialState extends State {
             @Override
             public boolean processMessage(Message message) {
@@ -2300,7 +2309,8 @@
     }
 
     private void updateConnectedClients(final List<WifiClient> wifiClients) {
-        if (mConnectedClientsTracker.updateConnectedClients(mForwardedDownstreams, wifiClients)) {
+        if (mConnectedClientsTracker.updateConnectedClients(mTetherMainSM.getAllDownstreams(),
+                wifiClients)) {
             reportTetherClientsChanged(mConnectedClientsTracker.getLastTetheredClients());
         }
     }
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index 0a37f54..f4b3749 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -16,6 +16,7 @@
 
 package com.android.networkstack.tethering;
 
+import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.content.pm.PackageManager.GET_ACTIVITIES;
 import static android.hardware.usb.UsbManager.USB_CONFIGURED;
 import static android.hardware.usb.UsbManager.USB_CONNECTED;
@@ -36,6 +37,7 @@
 import static android.net.TetheringManager.TETHERING_NCM;
 import static android.net.TetheringManager.TETHERING_USB;
 import static android.net.TetheringManager.TETHERING_WIFI;
+import static android.net.TetheringManager.TETHERING_WIFI_P2P;
 import static android.net.TetheringManager.TETHER_ERROR_IFACE_CFG_ERROR;
 import static android.net.TetheringManager.TETHER_ERROR_NO_ERROR;
 import static android.net.TetheringManager.TETHER_ERROR_UNKNOWN_IFACE;
@@ -49,12 +51,15 @@
 import static android.net.wifi.WifiManager.IFACE_IP_MODE_LOCAL_ONLY;
 import static android.net.wifi.WifiManager.IFACE_IP_MODE_TETHERED;
 import static android.net.wifi.WifiManager.WIFI_AP_STATE_ENABLED;
+import static android.system.OsConstants.RT_SCOPE_UNIVERSE;
 import static android.telephony.SubscriptionManager.INVALID_SUBSCRIPTION_ID;
 
+import static com.android.net.module.util.Inet4AddressUtils.inet4AddressToIntHTH;
 import static com.android.net.module.util.Inet4AddressUtils.intToInet4AddressHTH;
 import static com.android.networkstack.tethering.Tethering.UserRestrictionActionListener;
 import static com.android.networkstack.tethering.TetheringNotificationUpdater.DOWNSTREAM_NONE;
 import static com.android.networkstack.tethering.UpstreamNetworkMonitor.EVENT_ON_CAPABILITIES;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -108,11 +113,14 @@
 import android.net.RouteInfo;
 import android.net.TetherStatesParcel;
 import android.net.TetheredClient;
+import android.net.TetheredClient.AddressInfo;
 import android.net.TetheringCallbackStartedParcel;
 import android.net.TetheringConfigurationParcel;
 import android.net.TetheringRequestParcel;
+import android.net.dhcp.DhcpLeaseParcelable;
 import android.net.dhcp.DhcpServerCallbacks;
 import android.net.dhcp.DhcpServingParamsParcel;
+import android.net.dhcp.IDhcpEventCallbacks;
 import android.net.dhcp.IDhcpServer;
 import android.net.ip.DadProxy;
 import android.net.ip.IpNeighborMonitor;
@@ -122,7 +130,9 @@
 import android.net.util.NetworkConstants;
 import android.net.util.SharedLog;
 import android.net.wifi.SoftApConfiguration;
+import android.net.wifi.WifiClient;
 import android.net.wifi.WifiManager;
+import android.net.wifi.WifiManager.SoftApCallback;
 import android.net.wifi.p2p.WifiP2pGroup;
 import android.net.wifi.p2p.WifiP2pInfo;
 import android.net.wifi.p2p.WifiP2pManager;
@@ -168,6 +178,7 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Vector;
 
@@ -237,6 +248,7 @@
     private EntitlementManager mEntitleMgr;
     private OffloadController mOffloadCtrl;
     private PrivateAddressCoordinator mPrivateAddressCoordinator;
+    private SoftApCallback mSoftApCallback;
 
     private class TestContext extends BroadcastInterceptingContext {
         TestContext(Context base) {
@@ -568,8 +580,12 @@
                 ArgumentCaptor.forClass(PhoneStateListener.class);
         verify(mTelephonyManager).listen(phoneListenerCaptor.capture(),
                 eq(PhoneStateListener.LISTEN_ACTIVE_DATA_SUBSCRIPTION_ID_CHANGE));
-        verify(mWifiManager).registerSoftApCallback(any(), any());
         mPhoneStateListener = phoneListenerCaptor.getValue();
+
+        final ArgumentCaptor<SoftApCallback> softApCallbackCaptor =
+                ArgumentCaptor.forClass(SoftApCallback.class);
+        verify(mWifiManager).registerSoftApCallback(any(), softApCallbackCaptor.capture());
+        mSoftApCallback = softApCallbackCaptor.getValue();
     }
 
     private void setTetheringSupported(final boolean supported) {
@@ -1293,6 +1309,7 @@
                 new ArrayList<>();
         private final ArrayList<TetherStatesParcel> mTetherStates = new ArrayList<>();
         private final ArrayList<Integer> mOffloadStatus = new ArrayList<>();
+        private final ArrayList<List<TetheredClient>> mTetheredClients = new ArrayList<>();
 
         // This function will remove the recorded callbacks, so it must be called once for
         // each callback. If this is called after multiple callback, the order matters.
@@ -1338,6 +1355,13 @@
             return mTetherStates.remove(0);
         }
 
+        public void expectTetheredClientChanged(List<TetheredClient> expected) {
+            assertFalse(mTetheredClients.isEmpty());
+            final List<TetheredClient> actual = mTetheredClients.remove(0);
+            assertEquals(expected.size(), actual.size());
+            actual.forEach(client -> assertTrue(expected.contains(client))); // any order
+        }
+
         @Override
         public void onUpstreamChanged(Network network) {
             mActualUpstreams.add(network);
@@ -1355,7 +1379,7 @@
 
         @Override
         public void onTetherClientsChanged(List<TetheredClient> clients) {
-            // TODO: check this
+            mTetheredClients.add(clients);
         }
 
         @Override
@@ -1369,6 +1393,7 @@
             mTetheringConfigs.add(parcel.config);
             mTetherStates.add(parcel.states);
             mOffloadStatus.add(parcel.offloadStatus);
+            mTetheredClients.add(parcel.tetheredClients);
         }
 
         @Override
@@ -1398,6 +1423,7 @@
             assertNoUpstreamChangeCallback();
             assertNoConfigChangeCallback();
             assertNoStateChangeCallback();
+            assertTrue(mTetheredClients.isEmpty());
         }
 
         private void assertTetherConfigParcelEqual(@NonNull TetheringConfigurationParcel actual,
@@ -1437,6 +1463,7 @@
         // 1. Register one callback before running any tethering.
         mTethering.registerTetheringEventCallback(callback);
         mLooper.dispatchAll();
+        callback.expectTetheredClientChanged(Collections.emptyList());
         callback.expectUpstreamChanged(new Network[] {null});
         callback.expectConfigurationChanged(
                 mTethering.getTetheringConfiguration().toStableParcelable());
@@ -1463,6 +1490,7 @@
         // 3. Register second callback.
         mTethering.registerTetheringEventCallback(callback2);
         mLooper.dispatchAll();
+        callback2.expectTetheredClientChanged(Collections.emptyList());
         callback2.expectUpstreamChanged(upstreamState.network);
         callback2.expectConfigurationChanged(
                 mTethering.getTetheringConfiguration().toStableParcelable());
@@ -2054,6 +2082,114 @@
         verify(mPackageManager).getPackageInfo(PROVISIONING_APP_NAME[0], GET_ACTIVITIES);
     }
 
+    @Test
+    public void testUpdateConnectedClients() throws Exception {
+        TestTetheringEventCallback callback = new TestTetheringEventCallback();
+        runAsShell(NETWORK_SETTINGS, () -> {
+            mTethering.registerTetheringEventCallback(callback);
+            mLooper.dispatchAll();
+        });
+        callback.expectTetheredClientChanged(Collections.emptyList());
+
+        IDhcpEventCallbacks eventCallbacks;
+        final ArgumentCaptor<IDhcpEventCallbacks> dhcpEventCbsCaptor =
+                 ArgumentCaptor.forClass(IDhcpEventCallbacks.class);
+        // Start local-only tethering on the P2P interface and capture its DHCP callbacks.
+        mTethering.interfaceStatusChanged(TEST_P2P_IFNAME, true);
+        sendWifiP2pConnectionChanged(true, true, TEST_P2P_IFNAME);
+        mLooper.dispatchAll();
+        verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS)).startWithCallbacks(
+                any(), dhcpEventCbsCaptor.capture());
+        eventCallbacks = dhcpEventCbsCaptor.getValue();
+        // A P2P DHCP lease is reported as a TETHERING_WIFI_P2P tethered client.
+        final MacAddress testMac1 = MacAddress.fromString("11:11:11:11:11:11");
+        final ArrayList<DhcpLeaseParcelable> p2pLeases = new ArrayList<>();
+        p2pLeases.add(createDhcpLeaseParcelable("clientId1", testMac1, "192.168.50.24", 24,
+                Long.MAX_VALUE, "test1"));
+        notifyDhcpLeasesChanged(p2pLeases, eventCallbacks);
+        final List<TetheredClient> clients = toTetheredClients(p2pLeases, TETHERING_WIFI_P2P);
+        callback.expectTetheredClientChanged(clients);
+        reset(mDhcpServer); // So the next verify() matches only the wifi DHCP server start.
+
+        // Start wifi tethering while the P2P local-only downstream is still up.
+        mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
+        sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
+        mLooper.dispatchAll();
+        verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS)).startWithCallbacks(
+                any(), dhcpEventCbsCaptor.capture());
+        eventCallbacks = dhcpEventCbsCaptor.getValue();
+        // A SoftAP client seen before any DHCP lease is reported with no addresses.
+        final ArrayList<WifiClient> wifiClients = new ArrayList<>();
+        final MacAddress testMac2 = MacAddress.fromString("22:22:22:22:22:22");
+        final WifiClient testClient = mock(WifiClient.class);
+        when(testClient.getMacAddress()).thenReturn(testMac2);
+        wifiClients.add(testClient);
+        mSoftApCallback.onConnectedClientsChanged(wifiClients);
+        final TetheredClient noAddrClient = new TetheredClient(testMac2,
+                Collections.emptyList() /* addresses */, TETHERING_WIFI);
+        clients.add(noAddrClient);
+        callback.expectTetheredClientChanged(clients);
+
+        // After a wifi DHCP lease, the address-less entry is replaced by the leased one.
+        clients.remove(noAddrClient);
+        final ArrayList<DhcpLeaseParcelable> wifiLeases = new ArrayList<>();
+        wifiLeases.add(createDhcpLeaseParcelable("clientId2", testMac2, "192.168.43.24", 24,
+                Long.MAX_VALUE, "test2"));
+        notifyDhcpLeasesChanged(wifiLeases, eventCallbacks);
+        clients.addAll(toTetheredClients(wifiLeases, TETHERING_WIFI));
+        callback.expectTetheredClientChanged(clients);
+
+        // A callback registered while tethering is running gets current clients in onStarted.
+        TestTetheringEventCallback callback2 = new TestTetheringEventCallback();
+        runAsShell(NETWORK_SETTINGS, () -> {
+            mTethering.registerTetheringEventCallback(callback2);
+            mLooper.dispatchAll();
+        });
+        callback2.expectTetheredClientChanged(clients);
+    }
+
+    private void notifyDhcpLeasesChanged(List<DhcpLeaseParcelable> leases,
+            IDhcpEventCallbacks dhcpCallbacks) throws Exception {
+        dhcpCallbacks.onLeasesChanged(leases); // Act as the DHCP server reporting new leases,
+        mLooper.dispatchAll(); // then dispatch all pending messages on mLooper.
+    }
+
+    /** Returns the TetheredClients the test expects for these leases, as a mutable list. */
+    private List<TetheredClient> toTetheredClients(List<DhcpLeaseParcelable> leaseParcelables,
+            int type) {
+        final List<TetheredClient> clients = new ArrayList<>(leaseParcelables.size());
+        for (DhcpLeaseParcelable lease : leaseParcelables) {
+            // Private IPv4 addresses still get global scope, as per RFC 6724 section 3.2.
+            final LinkAddress address = new LinkAddress(
+                    intToInet4AddressHTH(lease.netAddr), lease.prefixLength,
+                    0 /* flags */, RT_SCOPE_UNIVERSE,
+                    lease.expTime /* deprecationTime */, lease.expTime /* expirationTime */);
+            // One client per lease, carrying exactly one address (the leased one).
+            final AddressInfo addressInfo = new AddressInfo(address, lease.hostname);
+            clients.add(new TetheredClient(
+                    MacAddress.fromBytes(lease.hwAddr),
+                    Collections.singletonList(addressInfo),
+                    type));
+        }
+
+        return clients;
+    }
+
+    /** Returns a lease for {@code hwAddr}; {@code netAddr} must be an IPv4 literal. */
+    private DhcpLeaseParcelable createDhcpLeaseParcelable(final String clientId,
+            final MacAddress hwAddr, final String netAddr, final int prefixLength,
+            final long expTime, final String hostname) {
+        final DhcpLeaseParcelable lease = new DhcpLeaseParcelable();
+        lease.clientId = clientId.getBytes();
+        lease.hwAddr = hwAddr.toByteArray();
+        lease.netAddr = inet4AddressToIntHTH(
+                (Inet4Address) InetAddresses.parseNumericAddress(netAddr));
+        lease.prefixLength = prefixLength;
+        lease.expTime = expTime;
+        lease.hostname = hostname;
+        return lease;
+    }
+
     // TODO: Test that a request for hotspot mode doesn't interfere with an
     // already operating tethering mode interface.
 }
diff --git a/framework/Android.bp b/framework/Android.bp
new file mode 100644
index 0000000..8db8d76
--- /dev/null
+++ b/framework/Android.bp
@@ -0,0 +1,29 @@
+//
+// Copyright (C) 2020 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// TODO: use a java_library in the bootclasspath instead of exporting raw sources.
+filegroup {
+    name: "framework-connectivity-sources",
+    srcs: [
+        "src/**/*.java",
+        "src/**/*.aidl",
+    ],
+    path: "src",
+    visibility: [
+        "//frameworks/base",
+        "//packages/modules/Connectivity:__subpackages__",
+    ],
+}
diff --git a/framework/src/com/android/connectivity/aidl/INetworkAgent.aidl b/framework/src/com/android/connectivity/aidl/INetworkAgent.aidl
new file mode 100644
index 0000000..1af9e76
--- /dev/null
+++ b/framework/src/com/android/connectivity/aidl/INetworkAgent.aidl
@@ -0,0 +1,46 @@
+/**
+ * Copyright (c) 2020, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.connectivity.aidl;
+
+import android.net.NattKeepalivePacketData;
+import android.net.TcpKeepalivePacketData;
+
+import com.android.connectivity.aidl.INetworkAgentRegistry;
+
+/**
+ * Interface to notify a NetworkAgent of connectivity events. All methods are oneway.
+ * @hide
+ */
+oneway interface INetworkAgent {
+    void onRegistered(in INetworkAgentRegistry registry); // agent reports back via registry
+    void onDisconnected();
+    void onBandwidthUpdateRequested();
+    void onValidationStatusChanged(int validationStatus,
+            in @nullable String captivePortalUrl);
+    void onSaveAcceptUnvalidated(boolean acceptUnvalidated);
+    void onStartNattSocketKeepalive(int slot, int intervalDurationMs,
+        in NattKeepalivePacketData packetData);
+    void onStartTcpSocketKeepalive(int slot, int intervalDurationMs,
+        in TcpKeepalivePacketData packetData);
+    void onStopSocketKeepalive(int slot);
+    void onSignalStrengthThresholdsUpdated(in int[] thresholds);
+    void onPreventAutomaticReconnect();
+    void onAddNattKeepalivePacketFilter(int slot,
+        in NattKeepalivePacketData packetData);
+    void onAddTcpKeepalivePacketFilter(int slot,
+        in TcpKeepalivePacketData packetData);
+    void onRemoveKeepalivePacketFilter(int slot);
+}
diff --git a/framework/src/com/android/connectivity/aidl/INetworkAgentRegistry.aidl b/framework/src/com/android/connectivity/aidl/INetworkAgentRegistry.aidl
new file mode 100644
index 0000000..d42a340
--- /dev/null
+++ b/framework/src/com/android/connectivity/aidl/INetworkAgentRegistry.aidl
@@ -0,0 +1,36 @@
+/**
+ * Copyright (c) 2020, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.connectivity.aidl;
+
+import android.net.LinkProperties;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.NetworkInfo;
+
+/**
+ * Interface used by NetworkAgents to report network properties, score and events.
+ * @hide
+ */
+oneway interface INetworkAgentRegistry {
+    void sendNetworkCapabilities(in NetworkCapabilities nc);
+    void sendLinkProperties(in LinkProperties lp);
+    // TODO: consider replacing this with "markConnected()" and removing it.
+    void sendNetworkInfo(in NetworkInfo info);
+    void sendScore(int score);
+    void sendExplicitlySelected(boolean explicitlySelected, boolean acceptPartial);
+    void sendSocketKeepaliveEvent(int slot, int reason);
+    void sendUnderlyingNetworks(in @nullable List<Network> networks);
+}
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 85d0a2e..803c9d8 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -18,25 +18,32 @@
 import android.app.Instrumentation
 import android.content.Context
 import android.net.ConnectivityManager
+import android.net.InetAddresses
+import android.net.IpPrefix
 import android.net.KeepalivePacketData
 import android.net.LinkAddress
 import android.net.LinkProperties
+import android.net.NattKeepalivePacketData
 import android.net.Network
 import android.net.NetworkAgent
-import android.net.NetworkAgent.CMD_ADD_KEEPALIVE_PACKET_FILTER
-import android.net.NetworkAgent.CMD_PREVENT_AUTOMATIC_RECONNECT
-import android.net.NetworkAgent.CMD_REMOVE_KEEPALIVE_PACKET_FILTER
-import android.net.NetworkAgent.CMD_REPORT_NETWORK_STATUS
-import android.net.NetworkAgent.CMD_SAVE_ACCEPT_UNVALIDATED
-import android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE
-import android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE
 import android.net.NetworkAgent.INVALID_NETWORK
 import android.net.NetworkAgent.VALID_NETWORK
 import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_CONGESTED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN
+import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
+import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
+import android.net.NetworkCapabilities.TRANSPORT_TEST
+import android.net.NetworkCapabilities.TRANSPORT_VPN
 import android.net.NetworkInfo
 import android.net.NetworkProvider
 import android.net.NetworkRequest
+import android.net.RouteInfo
 import android.net.SocketKeepalive
 import android.net.StringNetworkSpecifier
 import android.net.Uri
@@ -51,15 +58,14 @@
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStopSocketKeepalive
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnValidationStatus
 import android.os.Build
-import android.os.Bundle
-import android.os.Handler
 import android.os.HandlerThread
 import android.os.Looper
 import android.os.Message
-import android.os.Messenger
+import android.util.DebugUtils.valueToString
 import androidx.test.InstrumentationRegistry
 import androidx.test.runner.AndroidJUnit4
-import com.android.internal.util.AsyncChannel
+import com.android.connectivity.aidl.INetworkAgent
+import com.android.connectivity.aidl.INetworkAgentRegistry
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
@@ -68,7 +74,6 @@
 import com.android.testutils.TestableNetworkCallback
 import org.junit.After
 import org.junit.Assert.assertArrayEquals
-import org.junit.Assert.fail
 import org.junit.Before
 import org.junit.Rule
 import org.junit.Test
@@ -79,8 +84,8 @@
 import org.mockito.ArgumentMatchers.eq
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
+import org.mockito.Mockito.timeout
 import org.mockito.Mockito.verify
-import java.net.InetAddress
 import java.time.Duration
 import java.util.Arrays
 import java.util.UUID
@@ -90,6 +95,7 @@
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
 import kotlin.test.assertTrue
+import kotlin.test.fail
 
 // This test doesn't really have a constraint on how fast the methods should return. If it's
 // going to fail, it will simply wait forever, so setting a high timeout lowers the flake ratio
@@ -122,12 +128,12 @@
     @Rule @JvmField
     val ignoreRule = DevSdkIgnoreRule(ignoreClassUpTo = Build.VERSION_CODES.Q)
 
-    private val LOCAL_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.1")
-    private val REMOTE_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.2")
+    private val LOCAL_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1")
+    private val REMOTE_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.2")
 
     private val mCM = realContext.getSystemService(ConnectivityManager::class.java)
     private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")
-    private val mFakeConnectivityService by lazy { FakeConnectivityService(mHandlerThread.looper) }
+    private val mFakeConnectivityService = FakeConnectivityService()
 
     private class Provider(context: Context, looper: Looper) :
             NetworkProvider(context, looper, "NetworkAgentTest NetworkProvider")
@@ -154,39 +160,30 @@
      * This fake only supports speaking to one harnessed agent at a time because it
      * only keeps track of one async channel.
      */
-    private class FakeConnectivityService(looper: Looper) {
-        private val CMD_EXPECT_DISCONNECT = 1
-        private var disconnectExpected = false
-        private val msgHistory = ArrayTrackRecord<Message>().newReadHead()
-        private val asyncChannel = AsyncChannel()
-        private val handler = object : Handler(looper) {
-            override fun handleMessage(msg: Message) {
-                msgHistory.add(Message.obtain(msg)) // make a copy as the original will be recycled
-                when (msg.what) {
-                    CMD_EXPECT_DISCONNECT -> disconnectExpected = true
-                    AsyncChannel.CMD_CHANNEL_HALF_CONNECTED ->
-                        asyncChannel.sendMessage(AsyncChannel.CMD_CHANNEL_FULL_CONNECTION)
-                    AsyncChannel.CMD_CHANNEL_DISCONNECTED ->
-                        if (!disconnectExpected) {
-                            fail("Agent unexpectedly disconnected")
-                        } else {
-                            disconnectExpected = false
-                        }
-                }
-            }
+    private class FakeConnectivityService {
+        val mockRegistry = mock(INetworkAgentRegistry::class.java)
+        private var agentField: INetworkAgent? = null
+        private val registry = object : INetworkAgentRegistry.Stub(),
+                INetworkAgentRegistry by mockRegistry {
+            // asBinder has implementations in both INetworkAgentRegistry.Stub and mockRegistry, so
+            // it needs to be disambiguated. Just fail the test as it should be unused here.
+            // asBinder is used when sending the registry in binder transactions, so not in this
+            // test (the test just uses in-process direct calls). If it were used across processes,
+            // using the Stub super.asBinder() implementation would allow sending the registry in
+            // binder transactions, while recording incoming calls on the other mockito-generated
+            // methods.
+            override fun asBinder() = fail("asBinder should be unused in this test")
         }
 
-        fun connect(agentMsngr: Messenger) = asyncChannel.connect(realContext, handler, agentMsngr)
+        val agent: INetworkAgent
+            get() = agentField ?: fail("No INetworkAgent")
 
-        fun disconnect() = asyncChannel.disconnect()
+        fun connect(agent: INetworkAgent) {
+            this.agentField = agent
+            agent.onRegistered(registry)
+        }
 
-        fun sendMessage(what: Int, arg1: Int = 0, arg2: Int = 0, obj: Any? = null) =
-            asyncChannel.sendMessage(Message(what, arg1, arg2, obj))
-
-        fun expectMessage(what: Int) =
-            assertNotNull(msgHistory.poll(DEFAULT_TIMEOUT_MS) { it.what == what })
-
-        fun willExpectDisconnectOnce() = handler.sendEmptyMessage(CMD_EXPECT_DISCONNECT)
+        fun disconnect() = agent.onDisconnected()
     }
 
     private open class TestableNetworkAgent(
@@ -314,22 +311,23 @@
     private fun createNetworkAgent(
         context: Context = realContext,
         name: String? = null,
-        nc: NetworkCapabilities = NetworkCapabilities(),
-        lp: LinkProperties = LinkProperties()
+        initialNc: NetworkCapabilities? = null,
+        initialLp: LinkProperties? = null
     ): TestableNetworkAgent {
-        nc.apply {
-            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
-            removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
-            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
-            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
-            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+        val nc = initialNc ?: NetworkCapabilities().apply {
+            addTransportType(TRANSPORT_TEST)
+            removeCapability(NET_CAPABILITY_TRUSTED)
+            removeCapability(NET_CAPABILITY_INTERNET)
+            addCapability(NET_CAPABILITY_NOT_SUSPENDED)
+            addCapability(NET_CAPABILITY_NOT_ROAMING)
+            addCapability(NET_CAPABILITY_NOT_VPN)
             if (null != name) {
                 setNetworkSpecifier(StringNetworkSpecifier(name))
             }
         }
-        lp.apply {
-            addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 0))
+        val lp = initialLp ?: LinkProperties().apply {
+            addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 32))
+            addRoute(RouteInfo(IpPrefix("0.0.0.0/0"), null, null))
         }
         val config = NetworkAgentConfig.Builder().build()
         return TestableNetworkAgent(context, mHandlerThread.looper, nc, lp, config).also {
@@ -341,7 +339,7 @@
             Pair<TestableNetworkAgent, TestableNetworkCallback> {
         val request: NetworkRequest = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
@@ -386,7 +384,7 @@
         val callbacks = thresholds.map { strength ->
             val request = NetworkRequest.Builder()
                     .clearCapabilities()
-                    .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                    .addTransportType(TRANSPORT_TEST)
                     .setSignalStrength(strength)
                     .build()
             TestableNetworkCallback(DEFAULT_TIMEOUT_MS).also {
@@ -431,17 +429,15 @@
 
     @Test
     fun testSocketKeepalive(): Unit = createNetworkAgentWithFakeCS().let { agent ->
-        val packet = object : KeepalivePacketData(
+        val packet = NattKeepalivePacketData(
                 LOCAL_IPV4_ADDRESS /* srcAddress */, 1234 /* srcPort */,
                 REMOTE_IPV4_ADDRESS /* dstAddress */, 4567 /* dstPort */,
-                ByteArray(100 /* size */) { it.toByte() /* init */ }) {}
+                ByteArray(100 /* size */))
         val slot = 4
         val interval = 37
 
-        mFakeConnectivityService.sendMessage(CMD_ADD_KEEPALIVE_PACKET_FILTER,
-                arg1 = slot, obj = packet)
-        mFakeConnectivityService.sendMessage(CMD_START_SOCKET_KEEPALIVE,
-                arg1 = slot, arg2 = interval, obj = packet)
+        mFakeConnectivityService.agent.onAddNattKeepalivePacketFilter(slot, packet)
+        mFakeConnectivityService.agent.onStartNattSocketKeepalive(slot, interval, packet)
 
         agent.expectCallback<OnAddKeepalivePacketFilter>().let {
             assertEquals(it.slot, slot)
@@ -458,13 +454,11 @@
         // Check that when the agent sends a keepalive event, ConnectivityService receives the
         // expected message.
         agent.sendSocketKeepaliveEvent(slot, SocketKeepalive.ERROR_UNSUPPORTED)
-        mFakeConnectivityService.expectMessage(NetworkAgent.EVENT_SOCKET_KEEPALIVE).let() {
-            assertEquals(slot, it.arg1)
-            assertEquals(SocketKeepalive.ERROR_UNSUPPORTED, it.arg2)
-        }
+        verify(mFakeConnectivityService.mockRegistry, timeout(DEFAULT_TIMEOUT_MS))
+                .sendSocketKeepaliveEvent(slot, SocketKeepalive.ERROR_UNSUPPORTED)
 
-        mFakeConnectivityService.sendMessage(CMD_STOP_SOCKET_KEEPALIVE, arg1 = slot)
-        mFakeConnectivityService.sendMessage(CMD_REMOVE_KEEPALIVE_PACKET_FILTER, arg1 = slot)
+        mFakeConnectivityService.agent.onStopSocketKeepalive(slot)
+        mFakeConnectivityService.agent.onRemoveKeepalivePacketFilter(slot)
         agent.expectCallback<OnStopSocketKeepalive>().let {
             assertEquals(it.slot, slot)
         }
@@ -486,10 +480,10 @@
             it.getInterfaceName() == ifaceName
         }
         val nc = NetworkCapabilities(agent.nc)
-        nc.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED)
+        nc.addCapability(NET_CAPABILITY_NOT_METERED)
         agent.sendNetworkCapabilities(nc)
         callback.expectCapabilitiesThat(agent.network) {
-            it.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED)
+            it.hasCapability(NET_CAPABILITY_NOT_METERED)
         }
     }
 
@@ -503,12 +497,12 @@
         val name2 = UUID.randomUUID().toString()
         val request1 = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .setNetworkSpecifier(StringNetworkSpecifier(name1))
                 .build()
         val request2 = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .setNetworkSpecifier(StringNetworkSpecifier(name2))
                 .build()
         val callback1 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
@@ -519,7 +513,7 @@
         // Then file the interesting request
         val request = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
@@ -551,52 +545,71 @@
     @IgnoreUpTo(Build.VERSION_CODES.R)
     fun testSetUnderlyingNetworks() {
         val request = NetworkRequest.Builder()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-                .addTransportType(NetworkCapabilities.TRANSPORT_VPN)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED) // TODO: add to VPN!
+                .addTransportType(TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_VPN)
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .removeCapability(NET_CAPABILITY_TRUSTED) // TODO: add to VPN!
                 .build()
-        val callback = TestableNetworkCallback()
-        mCM.registerNetworkCallback(request, callback)
+        val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
+        registerNetworkCallback(request, callback)
 
         val nc = NetworkCapabilities().apply {
-            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-            addTransportType(NetworkCapabilities.TRANSPORT_VPN)
-            removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+            addTransportType(TRANSPORT_TEST)
+            addTransportType(TRANSPORT_VPN)
+            removeCapability(NET_CAPABILITY_NOT_VPN)
         }
         val defaultNetwork = mCM.activeNetwork
         assertNotNull(defaultNetwork)
-        val defaultNetworkTransports = mCM.getNetworkCapabilities(defaultNetwork).transportTypes
+        val defaultNetworkCapabilities = mCM.getNetworkCapabilities(defaultNetwork)
+        val defaultNetworkTransports = defaultNetworkCapabilities.transportTypes
 
-        val agent = createNetworkAgent(nc = nc)
+        val agent = createNetworkAgent(initialNc = nc)
         agent.register()
         agent.markConnected()
         callback.expectAvailableThenValidatedCallbacks(agent.network!!)
 
+        // Check that the default network's transport is propagated to the VPN.
         var vpnNc = mCM.getNetworkCapabilities(agent.network)
         assertNotNull(vpnNc)
-        assertTrue(NetworkCapabilities.TRANSPORT_VPN in vpnNc.transportTypes)
+
+        val testAndVpn = intArrayOf(TRANSPORT_TEST, TRANSPORT_VPN)
+        assertTrue(hasAllTransports(vpnNc, testAndVpn))
+        assertFalse(vpnNc.hasCapability(NET_CAPABILITY_NOT_VPN))
         assertTrue(hasAllTransports(vpnNc, defaultNetworkTransports),
                 "VPN transports ${Arrays.toString(vpnNc.transportTypes)}" +
                 " lacking transports from ${Arrays.toString(defaultNetworkTransports)}")
 
+        // Check that when no underlying networks are announced the underlying transport disappears.
         agent.setUnderlyingNetworks(listOf<Network>())
         callback.expectCapabilitiesThat(agent.network!!) {
-            it.transportTypes.size == 1 && it.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
+            it.transportTypes.size == 2 && hasAllTransports(it, testAndVpn)
         }
 
-        val expectedTransports = (defaultNetworkTransports.toSet() +
-                NetworkCapabilities.TRANSPORT_VPN).toIntArray()
+        // Put the underlying network back and check that the underlying transport reappears.
+        val expectedTransports = (defaultNetworkTransports.toSet() + TRANSPORT_TEST + TRANSPORT_VPN)
+                .toIntArray()
         agent.setUnderlyingNetworks(null)
         callback.expectCapabilitiesThat(agent.network!!) {
             it.transportTypes.size == expectedTransports.size &&
                     hasAllTransports(it, expectedTransports)
         }
 
+        // Check that some underlying capabilities are propagated.
+        // This is not very accurate because the test does not control the capabilities of the
+        // underlying networks, and because not congested, not roaming, and not suspended are the
+        // default anyway. It's still useful as an extra check though.
+        vpnNc = mCM.getNetworkCapabilities(agent.network)
+        for (cap in listOf(NET_CAPABILITY_NOT_CONGESTED,
+                NET_CAPABILITY_NOT_ROAMING,
+                NET_CAPABILITY_NOT_SUSPENDED)) {
+            val capStr = valueToString(NetworkCapabilities::class.java, "NET_CAPABILITY_", cap)
+            if (defaultNetworkCapabilities.hasCapability(cap) && !vpnNc.hasCapability(cap)) {
+                fail("$capStr not propagated from underlying: $defaultNetworkCapabilities")
+            }
+        }
+
         agent.unregister()
         callback.expectCallback<Lost>(agent.network)
-
-        mCM.unregisterNetworkCallback(callback)
     }
 
     @Test
@@ -606,7 +619,7 @@
         val mockCm = mock(ConnectivityManager::class.java)
         doReturn(mockCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
         createConnectedNetworkAgent(mockContext)
-        verify(mockCm).registerNetworkAgent(any(Messenger::class.java),
+        verify(mockCm).registerNetworkAgent(any(),
                 argThat<NetworkInfo> { it.detailedState == NetworkInfo.DetailedState.CONNECTING },
                 any(LinkProperties::class.java),
                 any(NetworkCapabilities::class.java),
@@ -618,7 +631,7 @@
     @Test
     fun testSetAcceptUnvalidated() {
         createNetworkAgentWithFakeCS().let { agent ->
-            mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 1)
+            mFakeConnectivityService.agent.onSaveAcceptUnvalidated(true)
             agent.expectCallback<OnSaveAcceptUnvalidated>().let {
                 assertTrue(it.accept)
             }
@@ -629,19 +642,18 @@
     @Test
     fun testSetAcceptUnvalidatedPreventAutomaticReconnect() {
         createNetworkAgentWithFakeCS().let { agent ->
-            mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 0)
-            mFakeConnectivityService.sendMessage(CMD_PREVENT_AUTOMATIC_RECONNECT)
+            mFakeConnectivityService.agent.onSaveAcceptUnvalidated(false)
+            mFakeConnectivityService.agent.onPreventAutomaticReconnect()
             agent.expectCallback<OnSaveAcceptUnvalidated>().let {
                 assertFalse(it.accept)
             }
             agent.expectCallback<OnAutomaticReconnectDisabled>()
             agent.assertNoCallback()
             // When automatic reconnect is turned off, the network is torn down and
-            // ConnectivityService sends a disconnect. This in turn causes the agent
-            // to send a DISCONNECTED message to CS.
-            mFakeConnectivityService.willExpectDisconnectOnce()
+            // ConnectivityService disconnects. As part of the disconnect, ConnectivityService will
+            // also send itself a message to unregister the NetworkAgent from its internal
+            // structure.
             mFakeConnectivityService.disconnect()
-            mFakeConnectivityService.expectMessage(AsyncChannel.CMD_CHANNEL_DISCONNECTED)
             agent.expectCallback<OnNetworkUnwanted>()
         }
     }
@@ -649,12 +661,10 @@
     @Test
     fun testPreventAutomaticReconnect() {
         createNetworkAgentWithFakeCS().let { agent ->
-            mFakeConnectivityService.sendMessage(CMD_PREVENT_AUTOMATIC_RECONNECT)
+            mFakeConnectivityService.agent.onPreventAutomaticReconnect()
             agent.expectCallback<OnAutomaticReconnectDisabled>()
             agent.assertNoCallback()
-            mFakeConnectivityService.willExpectDisconnectOnce()
             mFakeConnectivityService.disconnect()
-            mFakeConnectivityService.expectMessage(AsyncChannel.CMD_CHANNEL_DISCONNECTED)
             agent.expectCallback<OnNetworkUnwanted>()
         }
     }
@@ -662,18 +672,14 @@
     @Test
     fun testValidationStatus() = createNetworkAgentWithFakeCS().let { agent ->
         val uri = Uri.parse("http://www.google.com")
-        val bundle = Bundle().apply {
-            putString(NetworkAgent.REDIRECT_URL_KEY, uri.toString())
-        }
-        mFakeConnectivityService.sendMessage(CMD_REPORT_NETWORK_STATUS,
-                arg1 = VALID_NETWORK, obj = bundle)
+        mFakeConnectivityService.agent.onValidationStatusChanged(VALID_NETWORK,
+                uri.toString())
         agent.expectCallback<OnValidationStatus>().let {
             assertEquals(it.status, VALID_NETWORK)
             assertEquals(it.uri, uri)
         }
 
-        mFakeConnectivityService.sendMessage(CMD_REPORT_NETWORK_STATUS,
-                arg1 = INVALID_NETWORK, obj = Bundle())
+        mFakeConnectivityService.agent.onValidationStatusChanged(INVALID_NETWORK, null)
         agent.expectCallback<OnValidationStatus>().let {
             assertEquals(it.status, INVALID_NETWORK)
             assertNull(it.uri)
@@ -687,7 +693,7 @@
         // First create a request to make sure the network is kept up
         val request1 = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback1 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS).also {
             registerNetworkCallback(request1, it)
@@ -697,7 +703,7 @@
         // Then file the interesting request
         val request = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
@@ -708,18 +714,18 @@
 
             // Send TEMP_NOT_METERED and check that the callback is called appropriately.
             val nc1 = NetworkCapabilities(agent.nc)
-                    .addCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                    .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             agent.sendNetworkCapabilities(nc1)
             callback.expectCapabilitiesThat(agent.network) {
-                it.hasCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             }
 
             // Remove TEMP_NOT_METERED and check that the callback is called appropriately.
             val nc2 = NetworkCapabilities(agent.nc)
-                    .removeCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                    .removeCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             agent.sendNetworkCapabilities(nc2)
             callback.expectCapabilitiesThat(agent.network) {
-                !it.hasCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                !it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             }
         }