Caps callback should be called when underlying networks are changed

Add a test to check if onCapabilitiesChanged will be called when
the underlying networks are changed.

Bug: 191918368
Test: atest CtsHostsideNetworkTests:HostsideVpnTests
Change-Id: I8dfb16e01199d41b1b65f69e129ec40e37f9ab0f
diff --git a/tests/cts/hostside/app/AndroidManifest.xml b/tests/cts/hostside/app/AndroidManifest.xml
index e5bae5f..d56e5d4 100644
--- a/tests/cts/hostside/app/AndroidManifest.xml
+++ b/tests/cts/hostside/app/AndroidManifest.xml
@@ -20,6 +20,7 @@
     <uses-permission android:name="android.permission.ACCESS_NETWORK_STATE"/>
     <uses-permission android:name="android.permission.ACCESS_WIFI_STATE"/>
     <uses-permission android:name="android.permission.CHANGE_WIFI_STATE"/>
+    <uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" />
     <uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
     <uses-permission android:name="android.permission.INTERNET"/>
     <uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION"/>
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyVpnService.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyVpnService.java
index 7d3d4fc..855abf7 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyVpnService.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyVpnService.java
@@ -17,11 +17,11 @@
 package com.android.cts.net.hostside;
 
 import android.content.Intent;
+import android.content.pm.PackageManager.NameNotFoundException;
 import android.net.Network;
 import android.net.ProxyInfo;
 import android.net.VpnService;
 import android.os.ParcelFileDescriptor;
-import android.content.pm.PackageManager.NameNotFoundException;
 import android.text.TextUtils;
 import android.util.Log;
 
@@ -38,6 +38,9 @@
     public static final String ACTION_ESTABLISHED = "com.android.cts.net.hostside.ESTABNLISHED";
     public static final String EXTRA_ALWAYS_ON = "is-always-on";
     public static final String EXTRA_LOCKDOWN_ENABLED = "is-lockdown-enabled";
+    public static final String CMD_CONNECT = "connect";
+    public static final String CMD_DISCONNECT = "disconnect";
+    public static final String CMD_UPDATE_UNDERLYING_NETWORKS = "update_underlying_networks";
 
     private ParcelFileDescriptor mFd = null;
     private PacketReflector mPacketReflector = null;
@@ -46,15 +49,24 @@
     public int onStartCommand(Intent intent, int flags, int startId) {
         String packageName = getPackageName();
         String cmd = intent.getStringExtra(packageName + ".cmd");
-        if ("disconnect".equals(cmd)) {
+        if (CMD_DISCONNECT.equals(cmd)) {
             stop();
-        } else if ("connect".equals(cmd)) {
+        } else if (CMD_CONNECT.equals(cmd)) {
             start(packageName, intent);
+        } else if (CMD_UPDATE_UNDERLYING_NETWORKS.equals(cmd)) {
+            updateUnderlyingNetworks(packageName, intent);
         }
 
         return START_NOT_STICKY;
     }
 
+    private void updateUnderlyingNetworks(String packageName, Intent intent) {
+        final ArrayList<Network> underlyingNetworks =
+                intent.getParcelableArrayListExtra(packageName + ".underlyingNetworks");
+        setUnderlyingNetworks(
+                (underlyingNetworks != null) ? underlyingNetworks.toArray(new Network[0]) : null);
+    }
+
     private void start(String packageName, Intent intent) {
         Builder builder = new Builder();
 
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 3abc4fb..65ea4b1 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -17,6 +17,8 @@
 package com.android.cts.net.hostside;
 
 import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.content.pm.PackageManager.FEATURE_TELEPHONY;
+import static android.content.pm.PackageManager.FEATURE_WIFI;
 import static android.net.ConnectivityManager.TYPE_VPN;
 import static android.net.NetworkCapabilities.TRANSPORT_VPN;
 import static android.os.Process.INVALID_UID;
@@ -33,12 +35,14 @@
 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
 
 import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
+import static com.android.testutils.Cleanup.testAndCleanup;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeTrue;
 
 import android.annotation.Nullable;
 import android.app.Activity;
@@ -65,7 +69,9 @@
 import android.net.VpnManager;
 import android.net.VpnService;
 import android.net.VpnTransportInfo;
+import android.net.cts.util.CtsNetUtils;
 import android.net.wifi.WifiManager;
+import android.os.Build;
 import android.os.Handler;
 import android.os.Looper;
 import android.os.ParcelFileDescriptor;
@@ -80,6 +86,7 @@
 import android.system.Os;
 import android.system.OsConstants;
 import android.system.StructPollfd;
+import android.telephony.TelephonyManager;
 import android.test.MoreAsserts;
 import android.text.TextUtils;
 import android.util.Log;
@@ -88,10 +95,14 @@
 
 import com.android.compatibility.common.util.BlockingBroadcastReceiver;
 import com.android.modules.utils.build.SdkLevel;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
+import com.android.testutils.RecorderCallback;
 import com.android.testutils.TestableNetworkCallback;
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
@@ -151,6 +162,7 @@
     private static final String PRIVATE_DNS_MODE_PROVIDER_HOSTNAME = "hostname";
     private static final String PRIVATE_DNS_MODE_OPPORTUNISTIC = "opportunistic";
     private static final String PRIVATE_DNS_SPECIFIER_SETTING = "private_dns_specifier";
+    private static final int NETWORK_CALLBACK_TIMEOUT_MS = 30_000;
 
     public static String TAG = "VpnTest";
     public static int TIMEOUT_MS = 3 * 1000;
@@ -163,6 +175,9 @@
     private ConnectivityManager mCM;
     private WifiManager mWifiManager;
     private RemoteSocketFactoryClient mRemoteSocketFactoryClient;
+    private CtsNetUtils mCtsNetUtils;
+    private PackageManager mPackageManager;
+    private TelephonyManager mTelephonyManager;
 
     Network mNetwork;
     NetworkCallback mCallback;
@@ -172,6 +187,9 @@
     private String mOldPrivateDnsMode;
     private String mOldPrivateDnsSpecifier;
 
+    @Rule
+    public final DevSdkIgnoreRule mDevSdkIgnoreRule = new DevSdkIgnoreRule();
+
     private boolean supportedHardware() {
         final PackageManager pm = getInstrumentation().getContext().getPackageManager();
         return !pm.hasSystemFeature("android.hardware.type.watch");
@@ -201,6 +219,10 @@
         mRemoteSocketFactoryClient = new RemoteSocketFactoryClient(mActivity);
         mRemoteSocketFactoryClient.bind();
         mDevice.waitForIdle();
+        mCtsNetUtils = new CtsNetUtils(getInstrumentation().getContext());
+        mPackageManager = getInstrumentation().getContext().getPackageManager();
+        mTelephonyManager =
+                getInstrumentation().getContext().getSystemService(TelephonyManager.class);
     }
 
     @After
@@ -210,6 +232,7 @@
         if (mCallback != null) {
             mCM.unregisterNetworkCallback(mCallback);
         }
+        mCtsNetUtils.tearDown();
         Log.i(TAG, "Stopping VPN");
         stopVpn();
         mActivity.finish();
@@ -266,6 +289,32 @@
         }
     }
 
+    private void updateUnderlyingNetworks(@Nullable ArrayList<Network> underlyingNetworks)
+            throws Exception {
+        final Intent intent = new Intent(mActivity, MyVpnService.class)
+                .putExtra(mPackageName + ".cmd", MyVpnService.CMD_UPDATE_UNDERLYING_NETWORKS)
+                .putParcelableArrayListExtra(
+                        mPackageName + ".underlyingNetworks", underlyingNetworks);
+        mActivity.startService(intent);
+    }
+
+    private void establishVpn(String[] addresses, String[] routes, String allowedApplications,
+            String disallowedApplications, @Nullable ProxyInfo proxyInfo,
+            @Nullable ArrayList<Network> underlyingNetworks, boolean isAlwaysMetered)
+            throws Exception {
+        final Intent intent = new Intent(mActivity, MyVpnService.class)
+                .putExtra(mPackageName + ".cmd", MyVpnService.CMD_CONNECT)
+                .putExtra(mPackageName + ".addresses", TextUtils.join(",", addresses))
+                .putExtra(mPackageName + ".routes", TextUtils.join(",", routes))
+                .putExtra(mPackageName + ".allowedapplications", allowedApplications)
+                .putExtra(mPackageName + ".disallowedapplications", disallowedApplications)
+                .putExtra(mPackageName + ".httpProxy", proxyInfo)
+                .putParcelableArrayListExtra(
+                        mPackageName + ".underlyingNetworks", underlyingNetworks)
+                .putExtra(mPackageName + ".isAlwaysMetered", isAlwaysMetered);
+        mActivity.startService(intent);
+    }
+
     // TODO: Consider replacing arguments with a Builder.
     private void startVpn(
         String[] addresses, String[] routes, String allowedApplications,
@@ -291,18 +340,8 @@
         mCM.registerNetworkCallback(request, mCallback);  // Unregistered in tearDown.
 
         // Start the service and wait up for TIMEOUT_MS ms for the VPN to come up.
-        Intent intent = new Intent(mActivity, MyVpnService.class)
-                .putExtra(mPackageName + ".cmd", "connect")
-                .putExtra(mPackageName + ".addresses", TextUtils.join(",", addresses))
-                .putExtra(mPackageName + ".routes", TextUtils.join(",", routes))
-                .putExtra(mPackageName + ".allowedapplications", allowedApplications)
-                .putExtra(mPackageName + ".disallowedapplications", disallowedApplications)
-                .putExtra(mPackageName + ".httpProxy", proxyInfo)
-                .putParcelableArrayListExtra(
-                    mPackageName + ".underlyingNetworks", underlyingNetworks)
-                .putExtra(mPackageName + ".isAlwaysMetered", isAlwaysMetered);
-
-        mActivity.startService(intent);
+        establishVpn(addresses, routes, allowedApplications, disallowedApplications, proxyInfo,
+                underlyingNetworks, isAlwaysMetered);
         synchronized (mLock) {
             if (mNetwork == null) {
                  Log.i(TAG, "bf mLock");
@@ -344,7 +383,7 @@
         // and stopping a bound service has no effect. Instead, "start" the service again with an
         // Intent that tells it to disconnect.
         Intent intent = new Intent(mActivity, MyVpnService.class)
-                .putExtra(mPackageName + ".cmd", "disconnect");
+                .putExtra(mPackageName + ".cmd", MyVpnService.CMD_DISCONNECT);
         mActivity.startService(intent);
         synchronized (mLockShutdown) {
             try {
@@ -724,6 +763,83 @@
         setAndVerifyPrivateDns(initialMode);
     }
 
+    private NetworkRequest makeVpnNetworkRequest() {
+        return new NetworkRequest.Builder()
+                .addTransportType(NetworkCapabilities.TRANSPORT_VPN)
+                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+                .build();
+    }
+
+    private void expectUnderlyingNetworks(TestableNetworkCallback callback,
+            @Nullable List<Network> expectUnderlyingNetworks) {
+        callback.eventuallyExpect(RecorderCallback.CallbackEntry.NETWORK_CAPS_UPDATED,
+                NETWORK_CALLBACK_TIMEOUT_MS,
+                entry -> (Objects.equals(expectUnderlyingNetworks,
+                        ((RecorderCallback.CallbackEntry.CapabilitiesChanged) entry)
+                                .getCaps().getUnderlyingNetworks())));
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.S)
+    public void testChangeUnderlyingNetworks() throws Exception {
+        assumeTrue(supportedHardware());
+        assumeTrue(mPackageManager.hasSystemFeature(FEATURE_WIFI));
+        assumeTrue(mPackageManager.hasSystemFeature(FEATURE_TELEPHONY));
+        final TestableNetworkCallback callback = new TestableNetworkCallback();
+        final boolean isWifiEnabled = mWifiManager.isWifiEnabled();
+        testAndCleanup(() -> {
+            // Ensure both of wifi and mobile data are connected.
+            final Network wifiNetwork = mCtsNetUtils.ensureWifiConnected();
+            assertTrue("Wifi is not connected", (wifiNetwork != null));
+            final Network cellNetwork = mCtsNetUtils.connectToCell();
+            assertTrue("Mobile data is not connected", (cellNetwork != null));
+            // Store current default network.
+            final Network defaultNetwork = mCM.getActiveNetwork();
+            // Start VPN and set empty array as its underlying networks.
+            startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"} /* addresses */,
+                    new String[] {"0.0.0.0/0", "::/0"} /* routes */,
+                    "" /* allowedApplications */, "" /* disallowedApplications */,
+                    null /* proxyInfo */, new ArrayList<>() /* underlyingNetworks */,
+                    false /* isAlwaysMetered */);
+            // Acquire the NETWORK_SETTINGS permission for getting the underlying networks.
+            runWithShellPermissionIdentity(() -> {
+                mCM.registerNetworkCallback(makeVpnNetworkRequest(), callback);
+                // Check that this VPN doesn't have any underlying networks.
+                expectUnderlyingNetworks(callback, new ArrayList<Network>());
+
+                // Update the underlying networks to null and the underlying networks should follow
+                // the system default network.
+                updateUnderlyingNetworks(null);
+                expectUnderlyingNetworks(callback, List.of(defaultNetwork));
+
+                // Update the underlying networks to mobile data.
+                updateUnderlyingNetworks(new ArrayList<>(List.of(cellNetwork)));
+                // Check the underlying networks of NetworkCapabilities which comes from
+                // onCapabilitiesChanged is mobile data.
+                expectUnderlyingNetworks(callback, List.of(cellNetwork));
+
+                // Update the underlying networks to wifi.
+                updateUnderlyingNetworks(new ArrayList<>(List.of(wifiNetwork)));
+                // Check the underlying networks of NetworkCapabilities which comes from
+                // onCapabilitiesChanged is wifi.
+                expectUnderlyingNetworks(callback, List.of(wifiNetwork));
+
+                // Update the underlying networks to wifi and mobile data.
+                updateUnderlyingNetworks(new ArrayList<>(List.of(wifiNetwork, cellNetwork)));
+                // Check the underlying networks of NetworkCapabilities which comes from
+                // onCapabilitiesChanged is wifi and mobile data.
+                expectUnderlyingNetworks(callback, List.of(wifiNetwork, cellNetwork));
+            }, NETWORK_SETTINGS);
+        }, () -> {
+                if (isWifiEnabled) {
+                    mCtsNetUtils.ensureWifiConnected();
+                } else {
+                    mCtsNetUtils.ensureWifiDisconnected(null);
+                }
+            }, () -> {
+                mCM.unregisterNetworkCallback(callback);
+            });
+    }
+
     @Test
     public void testDefault() throws Exception {
         if (!supportedHardware()) return;
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
index 49b5f9d..180015d 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
@@ -33,6 +33,10 @@
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
+    public void testChangeUnderlyingNetworks() throws Exception {
+        runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testChangeUnderlyingNetworks");
+    }
+
     public void testDefault() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testDefault");
     }
diff --git a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
index ce873f7..7254319 100644
--- a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
+++ b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
@@ -380,6 +380,12 @@
         return mCellNetworkCallback != null;
     }
 
+    public void tearDown() {
+        if (cellConnectAttempted()) {
+            disconnectFromCell();
+        }
+    }
+
     private NetworkRequest makeWifiNetworkRequest() {
         return new NetworkRequest.Builder()
                 .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)