Merge "Add back tethering to mainline-postsubmit" into sc-dev
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index e496966..758c612 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -4722,6 +4722,22 @@
     }
 
     /**
+     * Temporarily allow bad wifi to override {@code config_networkAvoidBadWifi} configuration.
+     * For testing only: the service rejects this call on non-debuggable builds.
+     *
+     * @param timeMs Time, in {@link System#currentTimeMillis()} milliseconds, until which bad
+     *               wifi is allowed; must not be more than a limited time from now.
+     * @hide
+     */
+    public void setTestAllowBadWifiUntil(long timeMs) {
+        try {
+            mService.setTestAllowBadWifiUntil(timeMs);
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
+    /**
      * Requests that the system open the captive portal app on the specified network.
      *
      * <p>This is to be used on networks where a captive portal was detected, as per
diff --git a/framework/src/android/net/ConnectivitySettingsManager.java b/framework/src/android/net/ConnectivitySettingsManager.java
index 4644e4f..085de6b 100644
--- a/framework/src/android/net/ConnectivitySettingsManager.java
+++ b/framework/src/android/net/ConnectivitySettingsManager.java
@@ -562,7 +562,7 @@
     public static void setNetworkSwitchNotificationMaximumDailyCount(@NonNull Context context,
             @IntRange(from = 0) int count) {
         if (count < 0) {
-            throw new IllegalArgumentException("Count must be 0~10.");
+            throw new IllegalArgumentException("Count must not be negative.");
         }
         Settings.Global.putInt(
                 context.getContentResolver(), NETWORK_SWITCH_NOTIFICATION_DAILY_LIMIT, count);
@@ -585,6 +585,7 @@
 
     /**
      * Set minimum duration (to {@link Settings}) between each switching network notifications.
+     * The duration will be rounded down to the nearest millisecond, and must be positive.
      *
      * @param context The {@link Context} to set the setting.
      * @param duration The minimum duration between notifications when switching networks.
@@ -612,10 +613,11 @@
 
     /**
      * Set URL (to {@link Settings}) used for HTTP captive portal detection upon a new connection.
-     * This URL should respond with a 204 response to a GET request to indicate no captive portal is
-     * present. And this URL must be HTTP as redirect responses are used to find captive portal
-     * sign-in pages. If the URL set to null or be incorrect, it will result in captive portal
-     * detection failed and lost the connection.
+     * The URL is accessed to check for connectivity and presence of a captive portal on a network.
+     * The URL should respond with HTTP status 204 to a GET request, and the stack will use
+     * redirection status as a signal for captive portal detection.
+     * If the URL is set to null or is otherwise incorrect or inaccessible, the stack will fail to
+     * detect connectivity and portals. This will often result in loss of connectivity.
      *
      * @param context The {@link Context} to set the setting.
      * @param url The URL used for HTTP captive portal detection upon a new connection.
@@ -819,6 +821,7 @@
 
     /**
      * Set duration (to {@link Settings}) to keep a PendingIntent-based request.
+     * The duration will be rounded down to the nearest millisecond, and must be positive.
      *
      * @param context The {@link Context} to set the setting.
      * @param duration The duration to keep a PendingIntent-based request.
diff --git a/framework/src/android/net/IConnectivityManager.aidl b/framework/src/android/net/IConnectivityManager.aidl
index c434bbc..50ec781 100644
--- a/framework/src/android/net/IConnectivityManager.aidl
+++ b/framework/src/android/net/IConnectivityManager.aidl
@@ -226,4 +226,6 @@
     void offerNetwork(int providerId, in NetworkScore score,
             in NetworkCapabilities caps, in INetworkOfferCallback callback);
     void unofferNetwork(in INetworkOfferCallback callback);
+
+    void setTestAllowBadWifiUntil(long timeMs);
 }
diff --git a/framework/src/android/net/util/MultinetworkPolicyTracker.java b/framework/src/android/net/util/MultinetworkPolicyTracker.java
index 0b42a00..7e62d28 100644
--- a/framework/src/android/net/util/MultinetworkPolicyTracker.java
+++ b/framework/src/android/net/util/MultinetworkPolicyTracker.java
@@ -75,6 +75,7 @@
     private volatile boolean mAvoidBadWifi = true;
     private volatile int mMeteredMultipathPreference;
     private int mActiveSubId = SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+    private volatile long mTestAllowBadWifiUntilMs = 0;
 
     // Mainline module can't use internal HandlerExecutor, so add an identical executor here.
     private static class HandlerExecutor implements Executor {
@@ -162,14 +163,31 @@
      * Whether the device or carrier configuration disables avoiding bad wifi by default.
      */
     public boolean configRestrictsAvoidBadWifi() {
+        // While an unexpired test override is set, report that the config restricts avoiding
+        // bad wifi, so that behavior is controlled by the NETWORK_AVOID_BAD_WIFI setting.
+        // Read the volatile field once so both comparisons see the same value.
+        final long allowUntilMs = mTestAllowBadWifiUntilMs;
+        if (allowUntilMs > 0 && allowUntilMs > System.currentTimeMillis()) return true;
+
         // TODO: use R.integer.config_networkAvoidBadWifi directly
         final int id = mResources.get().getIdentifier("config_networkAvoidBadWifi",
                 "integer", mResources.getResourcesContext().getPackageName());
         return (getResourcesForActiveSubId().getInteger(id) == 0);
     }
 
+    /**
+     * Temporarily allow bad wifi to override {@code config_networkAvoidBadWifi} configuration.
+     * The override applies while {@code timeMs} is later than {@link System#currentTimeMillis()}.
+     */
+    public void setTestAllowBadWifiUntil(long timeMs) {
+        Log.d(TAG, "setTestAllowBadWifiUntil: " + timeMs);
+        mTestAllowBadWifiUntilMs = timeMs;
+        updateAvoidBadWifi();
+    }
+
+    @VisibleForTesting // Widened from private to protected for tests.
     @NonNull
-    private Resources getResourcesForActiveSubId() {
+    protected Resources getResourcesForActiveSubId() {
         return SubscriptionManager.getResourcesForSubId(
                 mResources.getResourcesContext(), mActiveSubId);
     }
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 899286b..1f91eca 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -652,6 +652,12 @@
     private static final int EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED = 54;
 
     /**
+     * Event to temporarily allow bad wifi, for a limited time, overriding
+     * {@code config_networkAvoidBadWifi}.
+     */
+    private static final int EVENT_SET_TEST_ALLOW_BAD_WIFI_UNTIL = 55;
+
+    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -663,6 +669,11 @@
      */
     private static final int PROVISIONING_NOTIFICATION_HIDE = 0;
 
+    /**
+     * The maximum time from now for which bad wifi may be allowed for testing.
+     */
+    private static final long MAX_TEST_ALLOW_BAD_WIFI_UNTIL_MS = 5 * 60 * 1000L;
+
     private static String eventName(int what) {
         return sMagicDecoderRing.get(what, Integer.toString(what));
     }
@@ -1276,6 +1287,20 @@
         public boolean getCellular464XlatEnabled() {
             return NetworkProperties.isCellular464XlatEnabled().orElse(true);
         }
+
+        /**
+         * Compares two PendingIntents using {@link PendingIntent#intentFilterEquals}.
+         */
+        public boolean intentFilterEquals(PendingIntent first, PendingIntent second) {
+            return first.intentFilterEquals(second);
+        }
+
+        /**
+         * Creates a new {@link LocationPermissionChecker} for {@code context}.
+         */
+        public LocationPermissionChecker makeLocationPermissionChecker(Context context) {
+            return new LocationPermissionChecker(context);
+        }
     }
 
     public ConnectivityService(Context context) {
@@ -1343,7 +1368,7 @@
         mNetd = netd;
         mTelephonyManager = (TelephonyManager) mContext.getSystemService(Context.TELEPHONY_SERVICE);
         mAppOpsManager = (AppOpsManager) mContext.getSystemService(Context.APP_OPS_SERVICE);
-        mLocationPermissionChecker = new LocationPermissionChecker(mContext);
+        mLocationPermissionChecker = mDeps.makeLocationPermissionChecker(mContext);
 
         // To ensure uid state is synchronized with Network Policy, register for
         // NetworkPolicyManagerService events must happen prior to NetworkPolicyManagerService
@@ -3934,7 +3959,7 @@
         for (Map.Entry<NetworkRequest, NetworkRequestInfo> entry : mNetworkRequests.entrySet()) {
             PendingIntent existingPendingIntent = entry.getValue().mPendingIntent;
             if (existingPendingIntent != null &&
-                    existingPendingIntent.intentFilterEquals(pendingIntent)) {
+                    mDeps.intentFilterEquals(existingPendingIntent, pendingIntent)) {
                 return entry.getValue();
             }
         }
@@ -4328,6 +4353,22 @@
         mHandler.sendMessage(mHandler.obtainMessage(EVENT_SET_AVOID_UNVALIDATED, network));
     }
 
+    @Override
+    public void setTestAllowBadWifiUntil(long timeMs) {
+        enforceSettingsPermission();
+        if (!Build.isDebuggable()) {
+            throw new IllegalStateException("Does not support in non-debuggable build");
+        }
+        // Only accept overrides that expire within MAX_TEST_ALLOW_BAD_WIFI_UNTIL_MS of now.
+        final long latestAllowedMs = System.currentTimeMillis() + MAX_TEST_ALLOW_BAD_WIFI_UNTIL_MS;
+        if (timeMs > latestAllowedMs) {
+            throw new IllegalArgumentException("It should not exceed "
+                    + MAX_TEST_ALLOW_BAD_WIFI_UNTIL_MS + "ms from now");
+        }
+        // Hand off to the handler, which passes the value to mMultinetworkPolicyTracker.
+        mHandler.sendMessage(mHandler.obtainMessage(EVENT_SET_TEST_ALLOW_BAD_WIFI_UNTIL, timeMs));
+    }
+
     private void handleSetAcceptUnvalidated(Network network, boolean accept, boolean always) {
         if (DBG) log("handleSetAcceptUnvalidated network=" + network +
                 " accept=" + accept + " always=" + always);
@@ -4870,6 +4911,10 @@
                 case EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED:
                     handleMobileDataPreferredUidsChanged();
                     break;
+                case EVENT_SET_TEST_ALLOW_BAD_WIFI_UNTIL:
+                    final long timeMs = ((Long) msg.obj).longValue();
+                    mMultinetworkPolicyTracker.setTestAllowBadWifiUntil(timeMs);
+                    break;
             }
         }
     }
diff --git a/tests/unit/java/android/net/ConnectivityDiagnosticsManagerTest.java b/tests/common/java/android/net/ConnectivityDiagnosticsManagerTest.java
similarity index 98%
rename from tests/unit/java/android/net/ConnectivityDiagnosticsManagerTest.java
rename to tests/common/java/android/net/ConnectivityDiagnosticsManagerTest.java
index 06e9405..294ed10 100644
--- a/tests/unit/java/android/net/ConnectivityDiagnosticsManagerTest.java
+++ b/tests/common/java/android/net/ConnectivityDiagnosticsManagerTest.java
@@ -36,9 +36,11 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import android.content.Context;
+import android.os.Build;
 import android.os.PersistableBundle;
 
 import androidx.test.InstrumentationRegistry;
+import androidx.test.filters.SdkSuppress;
 
 import org.junit.After;
 import org.junit.Before;
@@ -50,6 +52,7 @@
 import java.util.concurrent.Executor;
 
 @RunWith(JUnit4.class)
+@SdkSuppress(minSdkVersion = Build.VERSION_CODES.S, codeName = "S")
 public class ConnectivityDiagnosticsManagerTest {
     private static final int NET_ID = 1;
     private static final int DETECTION_METHOD = 2;
diff --git a/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt b/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt
new file mode 100644
index 0000000..ebaa787
--- /dev/null
+++ b/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt
@@ -0,0 +1,295 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net
+
+import android.net.ConnectivitySettingsManager.CAPTIVE_PORTAL_MODE
+import android.net.ConnectivitySettingsManager.CAPTIVE_PORTAL_MODE_AVOID
+import android.net.ConnectivitySettingsManager.CAPTIVE_PORTAL_MODE_IGNORE
+import android.net.ConnectivitySettingsManager.CAPTIVE_PORTAL_MODE_PROMPT
+import android.net.ConnectivitySettingsManager.CONNECTIVITY_RELEASE_PENDING_INTENT_DELAY_MS
+import android.net.ConnectivitySettingsManager.DATA_ACTIVITY_TIMEOUT_MOBILE
+import android.net.ConnectivitySettingsManager.DATA_ACTIVITY_TIMEOUT_WIFI
+import android.net.ConnectivitySettingsManager.DNS_RESOLVER_MAX_SAMPLES
+import android.net.ConnectivitySettingsManager.DNS_RESOLVER_MIN_SAMPLES
+import android.net.ConnectivitySettingsManager.DNS_RESOLVER_SAMPLE_VALIDITY_SECONDS
+import android.net.ConnectivitySettingsManager.DNS_RESOLVER_SUCCESS_THRESHOLD_PERCENT
+import android.net.ConnectivitySettingsManager.MOBILE_DATA_ALWAYS_ON
+import android.net.ConnectivitySettingsManager.NETWORK_SWITCH_NOTIFICATION_DAILY_LIMIT
+import android.net.ConnectivitySettingsManager.NETWORK_SWITCH_NOTIFICATION_RATE_LIMIT_MILLIS
+import android.net.ConnectivitySettingsManager.PRIVATE_DNS_DEFAULT_MODE
+import android.net.ConnectivitySettingsManager.PRIVATE_DNS_MODE_OFF
+import android.net.ConnectivitySettingsManager.PRIVATE_DNS_MODE_OPPORTUNISTIC
+import android.net.ConnectivitySettingsManager.WIFI_ALWAYS_REQUESTED
+import android.net.ConnectivitySettingsManager.getCaptivePortalMode
+import android.net.ConnectivitySettingsManager.getConnectivityKeepPendingIntentDuration
+import android.net.ConnectivitySettingsManager.getDnsResolverSampleRanges
+import android.net.ConnectivitySettingsManager.getDnsResolverSampleValidityDuration
+import android.net.ConnectivitySettingsManager.getDnsResolverSuccessThresholdPercent
+import android.net.ConnectivitySettingsManager.getMobileDataActivityTimeout
+import android.net.ConnectivitySettingsManager.getMobileDataAlwaysOn
+import android.net.ConnectivitySettingsManager.getNetworkSwitchNotificationMaximumDailyCount
+import android.net.ConnectivitySettingsManager.getNetworkSwitchNotificationRateDuration
+import android.net.ConnectivitySettingsManager.getPrivateDnsDefaultMode
+import android.net.ConnectivitySettingsManager.getWifiAlwaysRequested
+import android.net.ConnectivitySettingsManager.getWifiDataActivityTimeout
+import android.net.ConnectivitySettingsManager.setCaptivePortalMode
+import android.net.ConnectivitySettingsManager.setConnectivityKeepPendingIntentDuration
+import android.net.ConnectivitySettingsManager.setDnsResolverSampleRanges
+import android.net.ConnectivitySettingsManager.setDnsResolverSampleValidityDuration
+import android.net.ConnectivitySettingsManager.setDnsResolverSuccessThresholdPercent
+import android.net.ConnectivitySettingsManager.setMobileDataActivityTimeout
+import android.net.ConnectivitySettingsManager.setMobileDataAlwaysOn
+import android.net.ConnectivitySettingsManager.setNetworkSwitchNotificationMaximumDailyCount
+import android.net.ConnectivitySettingsManager.setNetworkSwitchNotificationRateDuration
+import android.net.ConnectivitySettingsManager.setPrivateDnsDefaultMode
+import android.net.ConnectivitySettingsManager.setWifiAlwaysRequested
+import android.net.ConnectivitySettingsManager.setWifiDataActivityTimeout
+import android.os.Build
+import android.platform.test.annotations.AppModeFull
+import android.provider.Settings
+import android.util.Range
+import androidx.test.InstrumentationRegistry
+import androidx.test.filters.SmallTest
+import com.android.net.module.util.ConnectivitySettingsUtils.getPrivateDnsModeAsString
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
+import junit.framework.Assert.assertEquals
+import org.junit.Test
+import org.junit.runner.RunWith
+import java.time.Duration
+import java.util.Objects
+import kotlin.test.assertFailsWith
+
+/**
+ * Tests for [ConnectivitySettingsManager].
+ *
+ * Build, install and run with:
+ * atest android.net.ConnectivitySettingsManagerTest
+ */
+@RunWith(DevSdkIgnoreRunner::class)
+@IgnoreUpTo(Build.VERSION_CODES.R)
+@SmallTest
+@AppModeFull(reason = "WRITE_SECURE_SETTINGS permission can't be granted to instant apps")
+class ConnectivitySettingsManagerTest {
+    private val instrumentation = InstrumentationRegistry.getInstrumentation()
+    private val context = instrumentation.context
+    private val resolver = context.contentResolver
+
+    private val defaultDuration = Duration.ofSeconds(0L)
+    private val testTime1 = 5L
+    private val testTime2 = 10L
+    private val settingsTypeGlobal = "global"
+    private val settingsTypeSecure = "secure"
+
+    /** Reads [name] from the Secure or Global settings table selected by [type]. */
+    private fun getSetting(name: String, type: String): String? =
+            if (type == settingsTypeSecure) Settings.Secure.getString(resolver, name)
+            else Settings.Global.getString(resolver, name)
+
+    /** Writes [value] to [name] in the Secure or Global settings table selected by [type]. */
+    private fun putSetting(name: String, type: String, value: String?) {
+        if (type == settingsTypeSecure) Settings.Secure.putString(resolver, name, value)
+        else Settings.Global.putString(resolver, name, value)
+    }
+
+    /** Restores saved setting values, deleting settings that did not exist before the test. */
+    private fun resetSettings(names: Array<String>, type: String, values: Array<String?>) {
+        names.forEachIndexed { i, name ->
+            val value = values[i]
+            if (value != null) {
+                putSetting(name, type, value)
+            } else {
+                // Drain the output so the delete has finished, then close the pipe.
+                android.os.ParcelFileDescriptor.AutoCloseInputStream(
+                        instrumentation.uiAutomation.executeShellCommand(
+                                "settings delete $type $name")).use { it.readBytes() }
+            }
+        }
+    }
+
+    /**
+     * Seeds [names] with the raw [testIntValues] and checks that [getter] returns [value1], then
+     * calls [setter] with [value2] and checks it is read back. Original values are restored.
+     */
+    private fun <T> testIntSetting(
+        names: Array<String>,
+        type: String,
+        value1: T,
+        value2: T,
+        getter: () -> T,
+        setter: (value: T) -> Unit,
+        testIntValues: IntArray
+    ) {
+        val originals = Array(names.size) { i -> getSetting(names[i], type) }
+        try {
+            names.forEachIndexed { i, name -> putSetting(name, type, testIntValues[i].toString()) }
+            assertEquals(value1, getter())
+            setter(value2)
+            assertEquals(value2, getter())
+        } finally {
+            resetSettings(names, type, originals)
+        }
+    }
+
+    @Test
+    fun testMobileDataActivityTimeout() {
+        testIntSetting(names = arrayOf(DATA_ACTIVITY_TIMEOUT_MOBILE), type = settingsTypeGlobal,
+                value1 = Duration.ofSeconds(testTime1), value2 = Duration.ofSeconds(testTime2),
+                getter = { getMobileDataActivityTimeout(context, defaultDuration) },
+                setter = { setMobileDataActivityTimeout(context, it) },
+                testIntValues = intArrayOf(testTime1.toInt()))
+    }
+
+    @Test
+    fun testWifiDataActivityTimeout() {
+        testIntSetting(names = arrayOf(DATA_ACTIVITY_TIMEOUT_WIFI), type = settingsTypeGlobal,
+                value1 = Duration.ofSeconds(testTime1), value2 = Duration.ofSeconds(testTime2),
+                getter = { getWifiDataActivityTimeout(context, defaultDuration) },
+                setter = { setWifiDataActivityTimeout(context, it) },
+                testIntValues = intArrayOf(testTime1.toInt()))
+    }
+
+    @Test
+    fun testDnsResolverSampleValidityDuration() {
+        testIntSetting(names = arrayOf(DNS_RESOLVER_SAMPLE_VALIDITY_SECONDS),
+                type = settingsTypeGlobal, value1 = Duration.ofSeconds(testTime1),
+                value2 = Duration.ofSeconds(testTime2),
+                getter = { getDnsResolverSampleValidityDuration(context, defaultDuration) },
+                setter = { setDnsResolverSampleValidityDuration(context, it) },
+                testIntValues = intArrayOf(testTime1.toInt()))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setDnsResolverSampleValidityDuration(context, Duration.ofSeconds(-1L)) }
+    }
+
+    @Test
+    fun testDnsResolverSuccessThresholdPercent() {
+        testIntSetting(names = arrayOf(DNS_RESOLVER_SUCCESS_THRESHOLD_PERCENT),
+                type = settingsTypeGlobal, value1 = 5, value2 = 10,
+                getter = { getDnsResolverSuccessThresholdPercent(context, 0 /* def */) },
+                setter = { setDnsResolverSuccessThresholdPercent(context, it) },
+                testIntValues = intArrayOf(5))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setDnsResolverSuccessThresholdPercent(context, -1) }
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setDnsResolverSuccessThresholdPercent(context, 120) }
+    }
+
+    @Test
+    fun testDnsResolverSampleRanges() {
+        testIntSetting(names = arrayOf(DNS_RESOLVER_MIN_SAMPLES, DNS_RESOLVER_MAX_SAMPLES),
+                type = settingsTypeGlobal, value1 = Range(1, 63), value2 = Range(2, 62),
+                getter = { getDnsResolverSampleRanges(context) },
+                setter = { setDnsResolverSampleRanges(context, it) },
+                testIntValues = intArrayOf(1, 63))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setDnsResolverSampleRanges(context, Range(-1, 62)) }
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setDnsResolverSampleRanges(context, Range(2, 65)) }
+    }
+
+    @Test
+    fun testNetworkSwitchNotificationMaximumDailyCount() {
+        testIntSetting(names = arrayOf(NETWORK_SWITCH_NOTIFICATION_DAILY_LIMIT),
+                type = settingsTypeGlobal, value1 = 5, value2 = 15,
+                getter = { getNetworkSwitchNotificationMaximumDailyCount(context, 0 /* def */) },
+                setter = { setNetworkSwitchNotificationMaximumDailyCount(context, it) },
+                testIntValues = intArrayOf(5))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setNetworkSwitchNotificationMaximumDailyCount(context, -1) }
+    }
+
+    @Test
+    fun testNetworkSwitchNotificationRateDuration() {
+        testIntSetting(names = arrayOf(NETWORK_SWITCH_NOTIFICATION_RATE_LIMIT_MILLIS),
+                type = settingsTypeGlobal, value1 = Duration.ofMillis(testTime1),
+                value2 = Duration.ofMillis(testTime2),
+                getter = { getNetworkSwitchNotificationRateDuration(context, defaultDuration) },
+                setter = { setNetworkSwitchNotificationRateDuration(context, it) },
+                testIntValues = intArrayOf(testTime1.toInt()))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setNetworkSwitchNotificationRateDuration(context, Duration.ofMillis(-1L)) }
+    }
+
+    @Test
+    fun testCaptivePortalMode() {
+        testIntSetting(names = arrayOf(CAPTIVE_PORTAL_MODE), type = settingsTypeGlobal,
+                value1 = CAPTIVE_PORTAL_MODE_AVOID, value2 = CAPTIVE_PORTAL_MODE_PROMPT,
+                getter = { getCaptivePortalMode(context, CAPTIVE_PORTAL_MODE_IGNORE) },
+                setter = { setCaptivePortalMode(context, it) },
+                testIntValues = intArrayOf(CAPTIVE_PORTAL_MODE_AVOID))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setCaptivePortalMode(context, 5 /* mode */) }
+    }
+
+    @Test
+    fun testPrivateDnsDefaultMode() {
+        val original = Settings.Global.getString(resolver, PRIVATE_DNS_DEFAULT_MODE)
+
+        try {
+            val mode = getPrivateDnsModeAsString(PRIVATE_DNS_MODE_OPPORTUNISTIC)
+            Settings.Global.putString(resolver, PRIVATE_DNS_DEFAULT_MODE, mode)
+            assertEquals(mode, getPrivateDnsDefaultMode(context))
+
+            setPrivateDnsDefaultMode(context, PRIVATE_DNS_MODE_OFF)
+            assertEquals(getPrivateDnsModeAsString(PRIVATE_DNS_MODE_OFF),
+                    getPrivateDnsDefaultMode(context))
+        } finally {
+            resetSettings(names = arrayOf(PRIVATE_DNS_DEFAULT_MODE), type = settingsTypeGlobal,
+                    values = arrayOf(original))
+        }
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setPrivateDnsDefaultMode(context, -1) }
+    }
+
+    @Test
+    fun testConnectivityKeepPendingIntentDuration() {
+        testIntSetting(names = arrayOf(CONNECTIVITY_RELEASE_PENDING_INTENT_DELAY_MS),
+                type = settingsTypeSecure, value1 = Duration.ofMillis(testTime1),
+                value2 = Duration.ofMillis(testTime2),
+                getter = { getConnectivityKeepPendingIntentDuration(context, defaultDuration) },
+                setter = { setConnectivityKeepPendingIntentDuration(context, it) },
+                testIntValues = intArrayOf(testTime1.toInt()))
+
+        assertFailsWith<IllegalArgumentException>("Expect fail but argument accepted.") {
+            setConnectivityKeepPendingIntentDuration(context, Duration.ofMillis(-1L)) }
+    }
+
+    @Test
+    fun testMobileDataAlwaysOn() {
+        testIntSetting(names = arrayOf(MOBILE_DATA_ALWAYS_ON), type = settingsTypeGlobal,
+                value1 = false, value2 = true,
+                getter = { getMobileDataAlwaysOn(context, true /* def */) },
+                setter = { setMobileDataAlwaysOn(context, it) },
+                testIntValues = intArrayOf(0))
+    }
+
+    @Test
+    fun testWifiAlwaysRequested() {
+        testIntSetting(names = arrayOf(WIFI_ALWAYS_REQUESTED), type = settingsTypeGlobal,
+                value1 = false, value2 = true,
+                getter = { getWifiAlwaysRequested(context, true /* def */) },
+                setter = { setWifiAlwaysRequested(context, it) },
+                testIntValues = intArrayOf(0))
+    }
+}
\ No newline at end of file
diff --git a/tests/common/java/android/net/InvalidPacketExceptionTest.kt b/tests/common/java/android/net/InvalidPacketExceptionTest.kt
new file mode 100644
index 0000000..320ac27
--- /dev/null
+++ b/tests/common/java/android/net/InvalidPacketExceptionTest.kt
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net
+
+import android.os.Build
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.SdkSuppress
+import org.junit.runner.RunWith
+import kotlin.test.Test
+import kotlin.test.assertEquals
+
+@RunWith(AndroidJUnit4::class)
+@SdkSuppress(minSdkVersion = Build.VERSION_CODES.S, codeName = "S")
+class InvalidPacketExceptionTest {
+    /** Positive, zero and negative error codes must round-trip through the constructor. */
+    @Test
+    fun testConstructor() {
+        val errorCodes = intArrayOf(123, 0, -123)
+        errorCodes.forEach { code -> assertEquals(code, InvalidPacketException(code).error) }
+    }
+}
\ No newline at end of file
diff --git a/tests/cts/net/src/android/net/cts/BatteryStatsManagerTest.java b/tests/cts/net/src/android/net/cts/BatteryStatsManagerTest.java
new file mode 100644
index 0000000..a54fd64
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/BatteryStatsManagerTest.java
@@ -0,0 +1,195 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.cts;
+
+import static android.Manifest.permission.UPDATE_DEVICE_STATS;
+
+import static androidx.test.InstrumentationRegistry.getContext;
+
+import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
+import static com.android.testutils.MiscAsserts.assertThrows;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.Network;
+import android.net.cts.util.CtsNetUtils;
+import android.os.BatteryStatsManager;
+import android.os.Build;
+import android.os.connectivity.CellularBatteryStats;
+import android.os.connectivity.WifiBatteryStats;
+import android.util.Log;
+
+import androidx.test.runner.AndroidJUnit4;
+
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.SkipPresubmit;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.util.function.Predicate;
+import java.util.function.Supplier;
+
+/**
+ * Tests for {@link BatteryStatsManager} network interface reporting and network battery stats.
+ */
+@RunWith(AndroidJUnit4.class)
+public class BatteryStatsManagerTest {
+    @Rule
+    public final DevSdkIgnoreRule ignoreRule = new DevSdkIgnoreRule();
+    private static final String TAG = BatteryStatsManagerTest.class.getSimpleName();
+    private static final String TEST_URL = "https://connectivitycheck.gstatic.com/generate_204";
+    // This value should be the same as BatteryStatsManager.BATTERY_STATUS_DISCHARGING.
+    // TODO: Use the constant once it's available in all branches
+    private static final int BATTERY_STATUS_DISCHARGING = 3;
+
+    private Context mContext;
+    private BatteryStatsManager mBsm;
+    private ConnectivityManager mCm;
+    private CtsNetUtils mCtsNetUtils;
+
+    @Before
+    public void setUp() throws Exception {
+        mContext = getContext();
+        mBsm = mContext.getSystemService(BatteryStatsManager.class);
+        mCm = mContext.getSystemService(ConnectivityManager.class);
+        mCtsNetUtils = new CtsNetUtils(mContext);
+    }
+
+    @Test
+    @SkipPresubmit(reason = "Virtual hardware does not support wifi battery stats")
+    public void testReportNetworkInterfaceForTransports() throws Exception {
+        try {
+            final Network cellNetwork = mCtsNetUtils.connectToCell();
+            final URL url = new URL(TEST_URL);
+
+            // Make sure wifi is disabled.
+            mCtsNetUtils.ensureWifiDisconnected(null /* wifiNetworkToCheck */);
+            // Simulate the device being unplugged from charging.
+            executeShellCommand("dumpsys battery unplug");
+            executeShellCommand("dumpsys battery set status " + BATTERY_STATUS_DISCHARGING);
+            executeShellCommand("dumpsys batterystats enable pretend-screen-off");
+
+            // Get cellular battery stats
+            CellularBatteryStats cellularStatsBefore = runAsShell(UPDATE_DEVICE_STATS,
+                    mBsm::getCellularBatteryStats);
+
+            // Generate traffic on cellular network.
+            generateNetworkTraffic(cellNetwork, url);
+
+            // The mobile battery stats are updated when a network stops being the default network.
+            // ConnectivityService will call BatteryStatsManager.reportMobileRadioPowerState when
+            // removing data activity tracking.
+            final Network wifiNetwork = mCtsNetUtils.ensureWifiConnected();
+
+            // Check cellular battery stats are updated.
+            runAsShell(UPDATE_DEVICE_STATS,
+                    () -> assertStatsEventually(mBsm::getCellularBatteryStats,
+                        cellularStatsAfter -> cellularBatteryStatsIncreased(
+                        cellularStatsBefore, cellularStatsAfter)));
+
+            WifiBatteryStats wifiStatsBefore = runAsShell(UPDATE_DEVICE_STATS,
+                    mBsm::getWifiBatteryStats);
+
+            // Generate traffic on wifi network.
+            generateNetworkTraffic(wifiNetwork, url);
+            // Wifi battery stats are updated when wifi on.
+            mCtsNetUtils.toggleWifi();
+
+            // Check wifi battery stats are updated.
+            runAsShell(UPDATE_DEVICE_STATS,
+                    () -> assertStatsEventually(mBsm::getWifiBatteryStats,
+                        wifiStatsAfter -> wifiBatteryStatsIncreased(wifiStatsBefore,
+                        wifiStatsAfter)));
+        } finally {
+            // Reset battery settings.
+            executeShellCommand("dumpsys battery reset");
+            executeShellCommand("dumpsys batterystats disable pretend-screen-off");
+        }
+    }
+
+    @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testReportNetworkInterfaceForTransports_throwsSecurityException()
+            throws Exception {
+        Network wifiNetwork = mCtsNetUtils.ensureWifiConnected();
+        final String iface = mCm.getLinkProperties(wifiNetwork).getInterfaceName();
+        final int[] transportType = mCm.getNetworkCapabilities(wifiNetwork).getTransportTypes();
+        assertThrows(SecurityException.class,
+                () -> mBsm.reportNetworkInterfaceForTransports(iface, transportType));
+    }
+
+    private void generateNetworkTraffic(Network network, URL url) throws IOException {
+        // Rethrow failures: otherwise the stats check only fails much later, with a misleading
+        // "battery stats did not increase" message.
+        final HttpURLConnection connection = (HttpURLConnection) network.openConnection(url);
+        try {
+            assertEquals(204, connection.getResponseCode());
+        } catch (IOException e) {
+            Log.e(TAG, "Generate traffic failed with exception " + e);
+            throw e;
+        } finally {
+            connection.disconnect();
+        }
+    }
+
+    private static <T> void assertStatsEventually(Supplier<T> statsGetter,
+            Predicate<T> statsChecker) throws Exception {
+        // Wait for updating mobile/wifi stats, and check stats every 10ms.
+        final int maxTries = 1000;
+        T result = null;
+        for (int i = 1; i <= maxTries; i++) {
+            result = statsGetter.get();
+            if (statsChecker.test(result)) return;
+            Thread.sleep(10);
+        }
+        final String stats = result instanceof CellularBatteryStats
+                ? "Cellular" : "Wifi";
+        fail(stats + " battery stats did not increase.");
+    }
+
+    private static boolean cellularBatteryStatsIncreased(CellularBatteryStats before,
+            CellularBatteryStats after) {
+        return (after.getNumBytesTx() > before.getNumBytesTx())
+                && (after.getNumBytesRx() > before.getNumBytesRx())
+                && (after.getNumPacketsTx() > before.getNumPacketsTx())
+                && (after.getNumPacketsRx() > before.getNumPacketsRx());
+    }
+
+    private static boolean wifiBatteryStatsIncreased(WifiBatteryStats before,
+            WifiBatteryStats after) {
+        return (after.getNumBytesTx() > before.getNumBytesTx())
+                && (after.getNumBytesRx() > before.getNumBytesRx())
+                && (after.getNumPacketsTx() > before.getNumPacketsTx())
+                && (after.getNumPacketsRx() > before.getNumPacketsRx());
+    }
+
+    private static String executeShellCommand(String command) {
+        final String result = runShellCommand(command).trim();
+        Log.d(TAG, "Output of '" + command + "': '" + result + "'");
+        return result;
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index c2613c8..4c967b8 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -29,6 +29,8 @@
 import static android.content.pm.PackageManager.FEATURE_WIFI_DIRECT;
 import static android.content.pm.PackageManager.GET_PERMISSIONS;
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
+import static android.net.ConnectivityManager.EXTRA_NETWORK;
+import static android.net.ConnectivityManager.EXTRA_NETWORK_REQUEST;
 import static android.net.ConnectivityManager.PROFILE_NETWORK_PREFERENCE_ENTERPRISE;
 import static android.net.ConnectivityManager.TYPE_BLUETOOTH;
 import static android.net.ConnectivityManager.TYPE_ETHERNET;
@@ -74,12 +76,14 @@
 import static com.android.compatibility.common.util.SystemUtil.callWithShellPermissionIdentity;
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
 import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
+import static com.android.modules.utils.build.SdkLevel.isAtLeastS;
 import static com.android.networkstack.apishim.ConstantsShim.BLOCKED_REASON_LOCKDOWN_VPN;
 import static com.android.networkstack.apishim.ConstantsShim.BLOCKED_REASON_NONE;
 import static com.android.testutils.MiscAsserts.assertThrows;
 import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
 import static com.android.testutils.TestPermissionUtil.runAsShell;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
@@ -111,11 +115,15 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
+import android.net.NetworkAgent;
+import android.net.NetworkAgentConfig;
 import android.net.NetworkCapabilities;
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkInfo.State;
+import android.net.NetworkProvider;
 import android.net.NetworkRequest;
+import android.net.NetworkScore;
 import android.net.NetworkSpecifier;
 import android.net.NetworkStateSnapshot;
 import android.net.NetworkUtils;
@@ -203,9 +211,12 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Supplier;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -240,6 +251,9 @@
     // Airplane Mode BroadcastReceiver Timeout
     private static final long AIRPLANE_MODE_CHANGE_TIMEOUT_MS = 10_000L;
 
+    // Timeout for uids allowed on restricted networks to take effect (be applied to netd)
+    private static final long APPLYING_UIDS_ALLOWED_ON_RESTRICTED_NETWORKS_TIMEOUT_MS = 3_000L;
+
     // Minimum supported keepalive counts for wifi and cellular.
     public static final int MIN_SUPPORTED_CELLULAR_KEEPALIVE_COUNT = 1;
     public static final int MIN_SUPPORTED_WIFI_KEEPALIVE_COUNT = 3;
@@ -849,6 +863,119 @@
         }
     }
 
+    private void runIdenticalPendingIntentsRequestTest(boolean useListen) throws Exception {
+        assumeTrue(mPackageManager.hasSystemFeature(FEATURE_WIFI));
+
+        // Disconnect before registering callbacks, reconnect later to fire them
+        mCtsNetUtils.ensureWifiDisconnected(null);
+
+        final NetworkRequest firstRequest = makeWifiNetworkRequest();
+        final NetworkRequest secondRequest = new NetworkRequest(firstRequest);
+        // Will match wifi or test, since transports are ORed; but there should only be wifi
+        secondRequest.networkCapabilities.addTransportType(TRANSPORT_TEST);
+
+        PendingIntent firstIntent = null;
+        PendingIntent secondIntent = null;
+        BroadcastReceiver receiver = null;
+
+        // Avoid receiving broadcasts from other runs by appending a timestamp
+        final String broadcastAction = NETWORK_CALLBACK_ACTION + System.currentTimeMillis();
+        try {
+            // TODO: replace with PendingIntent.FLAG_MUTABLE when this code compiles against S+
+            // Intent is mutable to receive EXTRA_NETWORK_REQUEST from ConnectivityService
+            final int flags = PendingIntent.FLAG_UPDATE_CURRENT | (1 << 25) /* FLAG_MUTABLE */;
+            final String extraBoolKey = "extra_bool";
+            firstIntent = PendingIntent.getBroadcast(mContext, 0 /* requestCode */,
+                    new Intent(broadcastAction).putExtra(extraBoolKey, false), flags);
+
+            if (useListen) {
+                mCm.registerNetworkCallback(firstRequest, firstIntent);
+            } else {
+                mCm.requestNetwork(firstRequest, firstIntent);
+            }
+
+            // Second intent equals the first as per filterEquals (extras don't count), so first
+            // intent will be updated with the new extras
+            secondIntent = PendingIntent.getBroadcast(mContext, 0 /* requestCode */,
+                    new Intent(broadcastAction).putExtra(extraBoolKey, true), flags);
+
+            // Because secondIntent.intentFilterEquals the first, the request should be replaced
+            if (useListen) {
+                mCm.registerNetworkCallback(secondRequest, secondIntent);
+            } else {
+                mCm.requestNetwork(secondRequest, secondIntent);
+            }
+
+            final IntentFilter filter = new IntentFilter(broadcastAction);
+
+            final CompletableFuture<Network> networkFuture = new CompletableFuture<>();
+            final AtomicInteger receivedCount = new AtomicInteger(0);
+            receiver = new BroadcastReceiver() {
+                @Override
+                public void onReceive(Context context, Intent intent) {
+                    // Throwing here (main thread) would crash the process; fail via the future.
+                    try {
+                        final NetworkRequest req = intent.getParcelableExtra(EXTRA_NETWORK_REQUEST);
+                        assertPendingIntentRequestMatches(req, secondRequest, useListen);
+                    } catch (AssertionError | RuntimeException e) {
+                        networkFuture.completeExceptionally(e);
+                        return;
+                    }
+                    receivedCount.incrementAndGet();
+                    networkFuture.complete(intent.getParcelableExtra(EXTRA_NETWORK));
+                }
+            };
+            mContext.registerReceiver(receiver, filter);
+
+            final Network wifiNetwork = mCtsNetUtils.ensureWifiConnected();
+            try {
+                assertEquals(wifiNetwork, networkFuture.get(
+                        NETWORK_CALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS));
+            } catch (TimeoutException e) {
+                throw new AssertionError("PendingIntent not received for " + secondRequest, e);
+            }
+
+            // Sleep briefly to check that only one callback is ever received (i.e. the first
+            // request was really replaced). This cannot guarantee a failure on very slow runs,
+            // but a regression should at least show up as very noticeable flakiness.
+            Thread.sleep(NO_CALLBACK_TIMEOUT_MS);
+
+            // TODO: BUG (b/189868426): this should also apply to listens
+            if (!useListen) {
+                assertEquals("PendingIntent should only be received once", 1, receivedCount.get());
+            }
+        } finally {
+            if (firstIntent != null) mCm.unregisterNetworkCallback(firstIntent);
+            if (secondIntent != null) mCm.unregisterNetworkCallback(secondIntent);
+            if (receiver != null) mContext.unregisterReceiver(receiver);
+            mCtsNetUtils.ensureWifiConnected();
+        }
+    }
+
+    private void assertPendingIntentRequestMatches(NetworkRequest broadcasted, NetworkRequest filed,
+            boolean useListen) {
+        // TODO: BUG (b/191713869): on S the request extra is null on listens
+        if (isAtLeastS() && useListen && broadcasted == null) return;
+        assertArrayEquals(filed.networkCapabilities.getCapabilities(),
+                broadcasted.networkCapabilities.getCapabilities());
+        // TODO: BUG (b/189868426): this should also apply to listens
+        if (useListen) return;
+        assertArrayEquals(filed.networkCapabilities.getTransportTypes(),
+                broadcasted.networkCapabilities.getTransportTypes());
+    }
+
+    @Test
+    @AppModeFull(reason = "Cannot get WifiManager in instant app mode")
+    public void testRegisterNetworkRequest_identicalPendingIntents() throws Exception {
+        runIdenticalPendingIntentsRequestTest(/* useListen= */ false);
+    }
+
+    @Test
+    @AppModeFull(reason = "Cannot get WifiManager in instant app mode")
+    public void testRegisterNetworkCallback_identicalPendingIntents() throws Exception {
+        runIdenticalPendingIntentsRequestTest(/* useListen= */ true);
+    }
+
     /**
      * Exercises the requestNetwork with NetworkCallback API. This checks to
      * see if we get a callback for an INTERNET request.
@@ -2399,8 +2526,7 @@
         } finally {
             resetValidationConfig();
             // Reconnect wifi to reset the wifi status
-            mCtsNetUtils.ensureWifiDisconnected(null /* wifiNetworkToCheck */);
-            mCtsNetUtils.ensureWifiConnected();
+            reconnectWifi();
         }
     }
 
@@ -2475,6 +2601,88 @@
         }
     }
 
+    @AppModeFull(reason = "WRITE_DEVICE_CONFIG permission can't be granted to instant apps")
+    @Test
+    public void testSetAvoidUnvalidated() throws Exception {
+        assumeTrue(TestUtils.shouldTestSApis());
+        // TODO: Allow in debuggable ROM only. To be replaced by FabricatedOverlay
+        assumeTrue(Build.isDebuggable());
+        final boolean canRunTest = mPackageManager.hasSystemFeature(FEATURE_WIFI)
+                && mPackageManager.hasSystemFeature(FEATURE_TELEPHONY);
+        assumeTrue("testSetAvoidUnvalidated cannot execute"
+                + " unless device supports WiFi and telephony", canRunTest);
+
+        final TestableNetworkCallback wifiCb = new TestableNetworkCallback();
+        final TestableNetworkCallback defaultCb = new TestableNetworkCallback();
+        final int previousAvoidBadWifi =
+                ConnectivitySettingsManager.getNetworkAvoidBadWifi(mContext);
+        // Setup is inside the outer try so the settings and wifi are restored if it fails.
+        try {
+            allowBadWifi();
+            final Network cellNetwork = mCtsNetUtils.connectToCell();
+            final Network wifiNetwork = prepareValidatedNetwork();
+            mCm.registerDefaultNetworkCallback(defaultCb);
+            mCm.registerNetworkCallback(makeWifiNetworkRequest(), wifiCb);
+            try {
+                // Verify wifi is the default network.
+                defaultCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
+                        entry -> wifiNetwork.equals(entry.getNetwork()));
+                wifiCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
+                        entry -> wifiNetwork.equals(entry.getNetwork()));
+                assertTrue(mCm.getNetworkCapabilities(wifiNetwork).hasCapability(
+                        NET_CAPABILITY_VALIDATED));
+
+                // Configure response code for unvalidated network
+                configTestServer(Status.INTERNAL_ERROR, Status.INTERNAL_ERROR);
+                mCm.reportNetworkConnectivity(wifiNetwork, false);
+                // Avoid bad wifi is disabled: the default network should stay on unvalidated wifi.
+                defaultCb.eventuallyExpect(CallbackEntry.NETWORK_CAPS_UPDATED,
+                        NETWORK_CALLBACK_TIMEOUT_MS,
+                        entry -> !((CallbackEntry.CapabilitiesChanged) entry).getCaps()
+                                .hasCapability(NET_CAPABILITY_VALIDATED));
+                wifiCb.eventuallyExpect(CallbackEntry.NETWORK_CAPS_UPDATED,
+                        NETWORK_CALLBACK_TIMEOUT_MS,
+                        entry -> !((CallbackEntry.CapabilitiesChanged) entry).getCaps()
+                                .hasCapability(NET_CAPABILITY_VALIDATED));
+
+                runAsShell(NETWORK_SETTINGS, () -> {
+                    mCm.setAvoidUnvalidated(wifiNetwork);
+                });
+                // Default network should be updated to validated cellular network.
+                defaultCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
+                        entry -> cellNetwork.equals(entry.getNetwork()));
+                // No update on wifi callback.
+                wifiCb.assertNoCallback();
+            } finally {
+                mCm.unregisterNetworkCallback(wifiCb);
+                mCm.unregisterNetworkCallback(defaultCb);
+            }
+        } finally {
+            resetAvoidBadWifi(previousAvoidBadWifi);
+            resetValidationConfig();
+            // Reconnect wifi to reset the wifi status
+            reconnectWifi();
+        }
+    }
+
+    private void resetAvoidBadWifi(int settingValue) {
+        setTestAllowBadWifiResource(0 /* timeMs */);
+        ConnectivitySettingsManager.setNetworkAvoidBadWifi(mContext, settingValue);
+    }
+
+    private void allowBadWifi() {
+        setTestAllowBadWifiResource(
+                System.currentTimeMillis() + WIFI_CONNECT_TIMEOUT_MS /* timeMs */);
+        ConnectivitySettingsManager.setNetworkAvoidBadWifi(mContext,
+                ConnectivitySettingsManager.NETWORK_AVOID_BAD_WIFI_IGNORE);
+    }
+
+    private void setTestAllowBadWifiResource(long timeMs) {
+        runAsShell(NETWORK_SETTINGS, () -> {
+            mCm.setTestAllowBadWifiUntil(timeMs);
+        });
+    }
+
     private Network expectNetworkHasCapability(Network network, int expectedNetCap, long timeout)
             throws Exception {
         final CompletableFuture<Network> future = new CompletableFuture();
@@ -2514,6 +2722,21 @@
         mHttpServer.start();
     }
 
+    /** Disconnects wifi via CtsNetUtils, reconnects it and returns the reconnected network. */
+    private Network reconnectWifi() {
+        mCtsNetUtils.ensureWifiDisconnected(null /* wifiNetworkToCheck */);
+        return mCtsNetUtils.ensureWifiConnected();
+    }
+
+    /** Reconnects wifi with the test server set to NO_CONTENT and waits for VALIDATED. */
+    private Network prepareValidatedNetwork() throws Exception {
+        prepareHttpServer();
+        configTestServer(Status.NO_CONTENT, Status.NO_CONTENT);
+        final Network reconnected = reconnectWifi();
+        return expectNetworkHasCapability(reconnected, NET_CAPABILITY_VALIDATED,
+                WIFI_CONNECT_TIMEOUT_MS);
+    }
+
     private Network preparePartialConnectivity() throws Exception {
         prepareHttpServer();
         // Configure response code for partial connectivity
@@ -2633,4 +2856,110 @@
                     mContext, mobileDataPreferredUids);
         }
     }
+
+    /** Sleeps for {@code ms} milliseconds, failing the test if the thread is interrupted. */
+    private void waitForMs(long ms) {
+        try {
+            Thread.sleep(ms);
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt(); // Preserve the interrupt status.
+            fail("Thread was interrupted");
+        }
+    }
+
+    private void assertBindSocketToNetworkSuccess(final Network network) throws Exception {
+        final CompletableFuture<Boolean> future = new CompletableFuture<>();
+        final ExecutorService executor = Executors.newSingleThreadExecutor();
+        try {
+            executor.execute(() -> {
+                for (int i = 0; i < 30 && !future.isDone(); i++) {
+                    waitForMs(100);
+                    try (Socket socket = new Socket()) {
+                        network.bindSocket(socket);
+                        future.complete(true);
+                    } catch (IOException ignored) { } // Not allowed on the network yet; retry.
+                }
+                future.complete(false); // No-op if already bound; otherwise retries ran out.
+            });
+            assertTrue("Cannot bind socket to restricted network " + network, future.get(
+                    APPLYING_UIDS_ALLOWED_ON_RESTRICTED_NETWORKS_TIMEOUT_MS, TimeUnit.MILLISECONDS));
+        } finally {
+            executor.shutdown();
+        }
+    }
+
+    @AppModeFull(reason = "WRITE_SECURE_SETTINGS permission can't be granted to instant apps")
+    @Test
+    public void testUidsAllowedOnRestrictedNetworks() throws Exception {
+        assumeTrue(TestUtils.shouldTestSApis());
+
+        final int uid = mPackageManager.getPackageUid(mContext.getPackageName(), 0 /* flag */);
+        final Set<Integer> originalUidsAllowedOnRestrictedNetworks =
+                ConnectivitySettingsManager.getUidsAllowedOnRestrictedNetworks(mContext);
+        // The CtsNetTestCases uid should not be listed in the UIDS_ALLOWED_ON_RESTRICTED_NETWORKS
+        // setting because the package has just been installed. In case the uid is in the setting
+        // by mistake, remove it and write the corrected set back to the setting.
+        originalUidsAllowedOnRestrictedNetworks.remove(uid);
+        ConnectivitySettingsManager.setUidsAllowedOnRestrictedNetworks(mContext,
+                originalUidsAllowedOnRestrictedNetworks);
+
+        final Handler h = new Handler(Looper.getMainLooper());
+        final TestableNetworkCallback testNetworkCb = new TestableNetworkCallback();
+        mCm.registerBestMatchingNetworkCallback(new NetworkRequest.Builder().clearCapabilities()
+                .addTransportType(NetworkCapabilities.TRANSPORT_TEST).build(), testNetworkCb, h);
+
+        // Create test network agent with restricted network.
+        final NetworkCapabilities nc = new NetworkCapabilities.Builder()
+                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED)
+                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
+                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
+                .build();
+        final NetworkScore score = new NetworkScore.Builder()
+                .setExiting(false)
+                .setTransportPrimary(false)
+                .setKeepConnectedReason(NetworkScore.KEEP_CONNECTED_FOR_HANDOVER)
+                .build();
+        final NetworkAgent agent = new NetworkAgent(mContext, Looper.getMainLooper(),
+                TAG, nc, new LinkProperties(), score, new NetworkAgentConfig.Builder().build(),
+                new NetworkProvider(mContext, Looper.getMainLooper(), TAG)) {};
+        try (Socket socket = new Socket()) {
+            // Register the agent inside the try so the callback and agent are always cleaned up.
+            runWithShellPermissionIdentity(() -> agent.register(),
+                    android.Manifest.permission.MANAGE_TEST_NETWORKS);
+            agent.markConnected();
+            final Network network = agent.getNetwork();
+
+            testNetworkCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
+                    entry -> network.equals(entry.getNetwork()));
+            // Verify that the network is restricted.
+            final NetworkCapabilities testNetworkNc = mCm.getNetworkCapabilities(network);
+            assertNotNull(testNetworkNc);
+            assertFalse(testNetworkNc.hasCapability(
+                    NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED));
+            // CtsNetTestCases package doesn't hold CONNECTIVITY_USE_RESTRICTED_NETWORKS, so it
+            // does not allow to bind socket to restricted network.
+            assertThrows(IOException.class, () -> network.bindSocket(socket));
+
+            // Add CtsNetTestCases uid to UIDS_ALLOWED_ON_RESTRICTED_NETWORKS setting, then it can
+            // bind socket to restricted network normally.
+            final Set<Integer> newUidsAllowedOnRestrictedNetworks =
+                    new ArraySet<>(originalUidsAllowedOnRestrictedNetworks);
+            newUidsAllowedOnRestrictedNetworks.add(uid);
+            ConnectivitySettingsManager.setUidsAllowedOnRestrictedNetworks(mContext,
+                    newUidsAllowedOnRestrictedNetworks);
+            // Wait a while for the allowed uids to be sent to netd.
+            // TODO: Have a definite signal to know when the uids have been sent to netd.
+            assertBindSocketToNetworkSuccess(network);
+        } finally {
+            mCm.unregisterNetworkCallback(testNetworkCb);
+            agent.unregister();
+
+            // Restore setting.
+            ConnectivitySettingsManager.setUidsAllowedOnRestrictedNetworks(mContext,
+                    originalUidsAllowedOnRestrictedNetworks);
+        }
+    }
 }
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
index ae38faa..a9a3380 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
@@ -43,6 +43,7 @@
 import android.net.ConnectivityManager;
 import android.net.IpSecAlgorithm;
 import android.net.IpSecManager;
+import android.net.IpSecManager.IpSecTunnelInterface;
 import android.net.IpSecTransform;
 import android.net.LinkAddress;
 import android.net.Network;
@@ -50,25 +51,33 @@
 import android.net.TestNetworkManager;
 import android.net.cts.PacketUtils.Payload;
 import android.net.cts.util.CtsNetUtils;
+import android.os.Build;
 import android.os.ParcelFileDescriptor;
 import android.platform.test.annotations.AppModeFull;
 
 import androidx.test.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
-import java.net.Inet6Address;
-import java.net.InetAddress;
-import java.net.NetworkInterface;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.NetworkInterface;
+
 @RunWith(AndroidJUnit4.class)
 @AppModeFull(reason = "MANAGE_TEST_NETWORKS permission can't be granted to instant apps")
 public class IpSecManagerTunnelTest extends IpSecBaseTest {
+    @Rule public final DevSdkIgnoreRule ignoreRule = new DevSdkIgnoreRule(); // for @IgnoreUpTo
+
     private static final String TAG = IpSecManagerTunnelTest.class.getSimpleName();
 
     private static final InetAddress LOCAL_OUTER_4 = InetAddress.parseNumericAddress("192.0.2.1");
@@ -78,6 +87,15 @@
     private static final InetAddress REMOTE_OUTER_6 =
             InetAddress.parseNumericAddress("2001:db8:1::2");
 
+    private static final InetAddress LOCAL_OUTER_4_NEW =
+            InetAddress.parseNumericAddress("192.0.2.101"); // local outer addr of sTunWrapperNew
+    private static final InetAddress REMOTE_OUTER_4_NEW =
+            InetAddress.parseNumericAddress("192.0.2.102"); // presumably its peer; confirm
+    private static final InetAddress LOCAL_OUTER_6_NEW =
+            InetAddress.parseNumericAddress("2001:db8:1::101"); // local outer addr of sTunWrapperNew
+    private static final InetAddress REMOTE_OUTER_6_NEW =
+            InetAddress.parseNumericAddress("2001:db8:1::102"); // presumably its peer; confirm
+
     private static final InetAddress LOCAL_INNER_4 =
             InetAddress.parseNumericAddress("198.51.100.1");
     private static final InetAddress REMOTE_INNER_4 =
@@ -95,10 +113,9 @@
     // Static state to reduce setup/teardown
     private static ConnectivityManager sCM;
     private static TestNetworkManager sTNM;
-    private static ParcelFileDescriptor sTunFd;
-    private static TestNetworkCallback sTunNetworkCallback;
-    private static Network sTunNetwork;
-    private static TunUtils sTunUtils;
+
+    private static TunNetworkWrapper sTunWrapper; // LOCAL_OUTER_4 / LOCAL_OUTER_6
+    private static TunNetworkWrapper sTunWrapperNew; // LOCAL_OUTER_4_NEW / LOCAL_OUTER_6_NEW
 
     private static Context sContext = InstrumentationRegistry.getContext();
     private static final CtsNetUtils mCtsNetUtils = new CtsNetUtils(sContext);
@@ -116,19 +133,8 @@
         // right appop permissions.
         mCtsNetUtils.setAppopPrivileged(OP_MANAGE_IPSEC_TUNNELS, true);
 
-        TestNetworkInterface testIface =
-                sTNM.createTunInterface(
-                        new LinkAddress[] {
-                            new LinkAddress(LOCAL_OUTER_4, IP4_PREFIX_LEN),
-                            new LinkAddress(LOCAL_OUTER_6, IP6_PREFIX_LEN)
-                        });
-
-        sTunFd = testIface.getFileDescriptor();
-        sTunNetworkCallback = mCtsNetUtils.setupAndGetTestNetwork(testIface.getInterfaceName());
-        sTunNetworkCallback.waitForAvailable();
-        sTunNetwork = sTunNetworkCallback.currentNetwork;
-
-        sTunUtils = new TunUtils(sTunFd);
+        sTunWrapper = new TunNetworkWrapper(LOCAL_OUTER_4, LOCAL_OUTER_6);
+        sTunWrapperNew = new TunNetworkWrapper(LOCAL_OUTER_4_NEW, LOCAL_OUTER_6_NEW);
     }
 
     @Before
@@ -139,24 +145,76 @@
         // Set to true before every run; some tests flip this.
         mCtsNetUtils.setAppopPrivileged(OP_MANAGE_IPSEC_TUNNELS, true);
 
-        // Clear sTunUtils state
-        sTunUtils.reset();
+        // Clear TunUtils state
+        sTunWrapper.utils.reset();
+        sTunWrapperNew.utils.reset();
+    }
+
+    private static void tearDownTunWrapperIfNotNull(TunNetworkWrapper tunWrapper) throws Exception {
+        if (tunWrapper != null) {
+            tunWrapper.tearDown();
+        }
     }
 
     @AfterClass
     public static void tearDownAfterClass() throws Exception {
         mCtsNetUtils.setAppopPrivileged(OP_MANAGE_IPSEC_TUNNELS, false);
 
-        sCM.unregisterNetworkCallback(sTunNetworkCallback);
-
-        sTNM.teardownTestNetwork(sTunNetwork);
-        sTunFd.close();
+        tearDownTunWrapperIfNotNull(sTunWrapper);
+        tearDownTunWrapperIfNotNull(sTunWrapperNew);
 
         InstrumentationRegistry.getInstrumentation()
                 .getUiAutomation()
                 .dropShellPermissionIdentity();
     }
 
+    private static class TunNetworkWrapper {
+        public final ParcelFileDescriptor fd;
+        public final TestNetworkCallback networkCallback;
+        public final Network network;
+        public final TunUtils utils;
+
+        TunNetworkWrapper(InetAddress... addresses) throws Exception {
+            final LinkAddress[] linkAddresses = new LinkAddress[addresses.length];
+            for (int i = 0; i < linkAddresses.length; i++) {
+                InetAddress addr = addresses[i];
+                int prefixLen = addr instanceof Inet4Address ? IP4_PREFIX_LEN : IP6_PREFIX_LEN;
+                linkAddresses[i] = new LinkAddress(addr, prefixLen);
+            }
+            try {
+                final TestNetworkInterface testIface = sTNM.createTunInterface(linkAddresses);
+                fd = testIface.getFileDescriptor();
+                networkCallback = mCtsNetUtils.setupAndGetTestNetwork(testIface.getInterfaceName());
+                networkCallback.waitForAvailable();
+                network = networkCallback.currentNetwork;
+            } catch (Exception e) {
+                // Release whatever was set up; a cleanup failure must not mask the real cause.
+                try {
+                    tearDown();
+                } catch (Exception suppressed) {
+                    e.addSuppressed(suppressed);
+                }
+                throw e;
+            }
+            utils = new TunUtils(fd);
+        }
+
+        public void tearDown() throws Exception {
+            try {
+                if (networkCallback != null) {
+                    sCM.unregisterNetworkCallback(networkCallback);
+                }
+                if (network != null) {
+                    sTNM.teardownTestNetwork(network);
+                }
+            } finally {
+                if (fd != null) {
+                    fd.close();
+                }
+            }
+        }
+    }
+
     @Test
     public void testSecurityExceptionCreateTunnelInterfaceWithoutAppop() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -166,7 +224,7 @@
 
         // Security exceptions are thrown regardless of IPv4/IPv6. Just test one
         try {
-            mISM.createIpSecTunnelInterface(LOCAL_INNER_6, REMOTE_INNER_6, sTunNetwork);
+            mISM.createIpSecTunnelInterface(LOCAL_INNER_6, REMOTE_INNER_6, sTunWrapper.network);
             fail("Did not throw SecurityException for Tunnel creation without appop");
         } catch (SecurityException expected) {
         }
@@ -196,11 +254,16 @@
          * Runs the test code, and returns the inner socket port, if any.
          *
          * @param ipsecNetwork The IPsec Interface based Network for binding sockets on
+         * @param tunnelIface The IPsec tunnel interface that will be tested
+         * @param underlyingTunUtils The utility of the IPsec tunnel interface's underlying TUN
+         *     network
          * @return the integer port of the inner socket if outbound, or 0 if inbound
          *     IpSecTunnelTestRunnable
          * @throws Exception if any part of the test failed.
          */
-        public abstract int run(Network ipsecNetwork) throws Exception;
+        public abstract int run(
+                Network ipsecNetwork, IpSecTunnelInterface tunnelIface, TunUtils underlyingTunUtils)
+                throws Exception;
     }
 
     private int getPacketSize(
@@ -265,7 +328,9 @@
                 int expectedPacketSize) {
             return new IpSecTunnelTestRunnable() {
                 @Override
-                public int run(Network ipsecNetwork) throws Exception {
+                public int run(
+                        Network ipsecNetwork, IpSecTunnelInterface tunnelIface, TunUtils tunUtils)
+                        throws Exception {
                     // Build a socket and send traffic
                     JavaUdpSocket socket = new JavaUdpSocket(localInner);
                     ipsecNetwork.bindSocket(socket.mSocket);
@@ -284,7 +349,7 @@
                     // Verify that an encrypted packet is sent. As of right now, checking encrypted
                     // body is not possible, due to the test not knowing some of the fields of the
                     // inner IP header (flow label, flags, etc)
-                    sTunUtils.awaitEspPacketNoPlaintext(
+                    tunUtils.awaitEspPacketNoPlaintext(
                             spi, TEST_DATA, encapPort != 0, expectedPacketSize);
 
                     socket.close();
@@ -312,7 +377,9 @@
                 throws Exception {
             return new IpSecTunnelTestRunnable() {
                 @Override
-                public int run(Network ipsecNetwork) throws Exception {
+                public int run(
+                        Network ipsecNetwork, IpSecTunnelInterface tunnelIface, TunUtils tunUtils)
+                        throws Exception {
                     // Build a socket and receive traffic
                     JavaUdpSocket socket = new JavaUdpSocket(localInner, innerSocketPort);
                     ipsecNetwork.bindSocket(socket.mSocket);
@@ -325,7 +392,7 @@
                                 socket.mSocket, IpSecManager.DIRECTION_OUT, inTransportTransform);
                     }
 
-                    sTunUtils.reflectPackets();
+                    tunUtils.reflectPackets();
 
                     // Receive packet from socket, and validate that the payload is correct
                     receiveAndValidatePacket(socket);
@@ -355,7 +422,9 @@
                 throws Exception {
             return new IpSecTunnelTestRunnable() {
                 @Override
-                public int run(Network ipsecNetwork) throws Exception {
+                public int run(
+                        Network ipsecNetwork, IpSecTunnelInterface tunnelIface, TunUtils tunUtils)
+                        throws Exception {
                     // Build a socket and receive traffic
                     JavaUdpSocket socket = new JavaUdpSocket(localInner);
                     ipsecNetwork.bindSocket(socket.mSocket);
@@ -391,7 +460,7 @@
                                         socket.getPort(),
                                         encapPort);
                     }
-                    sTunUtils.injectPacket(pkt);
+                    tunUtils.injectPacket(pkt);
 
                     // Receive packet from socket, and validate
                     receiveAndValidatePacket(socket);
@@ -404,6 +473,161 @@
         }
     }
 
+    /** Runs a tunnel test, migrates to {@code sTunWrapperNew} and re-tests over IPv4/IPv6. */
+    private class MigrateIpSecTunnelTestRunnableFactory implements IpSecTunnelTestRunnableFactory {
+        private final IpSecTunnelTestRunnableFactory mTestRunnableFactory;
+
+        MigrateIpSecTunnelTestRunnableFactory(boolean isOutputTest) {
+            if (isOutputTest) {
+                mTestRunnableFactory = new OutputIpSecTunnelTestRunnableFactory();
+            } else {
+                mTestRunnableFactory = new InputPacketGeneratorIpSecTunnelTestRunnableFactory();
+            }
+        }
+
+        @Override
+        public IpSecTunnelTestRunnable getIpSecTunnelTestRunnable(
+                boolean transportInTunnelMode,
+                int spi,
+                InetAddress localInner,
+                InetAddress remoteInner,
+                InetAddress localOuter,
+                InetAddress remoteOuter,
+                IpSecTransform inTransportTransform,
+                IpSecTransform outTransportTransform,
+                int encapPort,
+                int unusedInnerSocketPort,
+                int expectedPacketSize) {
+            return new IpSecTunnelTestRunnable() {
+                @Override
+                public int run(
+                        Network ipsecNetwork, IpSecTunnelInterface tunnelIface, TunUtils tunUtils)
+                        throws Exception {
+                    mTestRunnableFactory
+                            .getIpSecTunnelTestRunnable(
+                                    transportInTunnelMode,
+                                    spi,
+                                    localInner,
+                                    remoteInner,
+                                    localOuter,
+                                    remoteOuter,
+                                    inTransportTransform,
+                                    outTransportTransform,
+                                    encapPort,
+                                    unusedInnerSocketPort,
+                                    expectedPacketSize)
+                            .run(ipsecNetwork, tunnelIface, tunUtils);
+
+                    tunnelIface.setUnderlyingNetwork(sTunWrapperNew.network);
+
+                    // Verify migrating to IPv4 and IPv6 addresses. It ensures that not only
+                    // can IPsec tunnel migrate across interfaces, IPsec tunnel can also migrate to
+                    // a different address on the same interface.
+                    checkMigratedTunnel(
+                            localInner,
+                            remoteInner,
+                            LOCAL_OUTER_4_NEW,
+                            REMOTE_OUTER_4_NEW,
+                            encapPort != 0,
+                            transportInTunnelMode,
+                            sTunWrapperNew.utils,
+                            tunnelIface,
+                            ipsecNetwork);
+                    checkMigratedTunnel(
+                            localInner,
+                            remoteInner,
+                            LOCAL_OUTER_6_NEW,
+                            REMOTE_OUTER_6_NEW,
+                            false, // IPv6 does not support UDP encapsulation
+                            transportInTunnelMode,
+                            sTunWrapperNew.utils,
+                            tunnelIface,
+                            ipsecNetwork);
+
+                    return 0;
+                }
+            };
+        }
+
+        /** Re-applies tunnel transforms for the given outer addresses and re-runs the test. */
+        private void checkMigratedTunnel(
+                InetAddress localInner,
+                InetAddress remoteInner,
+                InetAddress localOuter,
+                InetAddress remoteOuter,
+                boolean useEncap,
+                boolean transportInTunnelMode,
+                TunUtils tunUtils,
+                IpSecTunnelInterface tunnelIface,
+                Network ipsecNetwork)
+                throws Exception {
+            // Preselect both SPI and encap port, to be used for both inbound and outbound tunnels.
+            // Re-uses the same SPI to ensure that even in cases of symmetric SPIs shared across
+            // tunnel and transport mode, packets are encrypted/decrypted properly based on the
+            // src/dst.
+            int spi = getRandomSpi(localOuter, remoteOuter);
+
+            int innerFamily = localInner instanceof Inet4Address ? AF_INET : AF_INET6;
+            int outerFamily = localOuter instanceof Inet4Address ? AF_INET : AF_INET6;
+            int expectedPacketSize =
+                    getPacketSize(innerFamily, outerFamily, useEncap, transportInTunnelMode);
+
+            // Build transport mode transforms and encapsulation socket for verifying
+            // transport-in-tunnel case and encapsulation case.
+            try (IpSecManager.SecurityParameterIndex inTransportSpi =
+                            mISM.allocateSecurityParameterIndex(localInner, spi);
+                    IpSecManager.SecurityParameterIndex outTransportSpi =
+                            mISM.allocateSecurityParameterIndex(remoteInner, spi);
+                    IpSecTransform inTransportTransform =
+                            buildIpSecTransform(sContext, inTransportSpi, null, remoteInner);
+                    IpSecTransform outTransportTransform =
+                            buildIpSecTransform(sContext, outTransportSpi, null, localInner);
+                    UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+                // Configure tunnel mode Transform parameters
+                IpSecTransform.Builder transformBuilder = new IpSecTransform.Builder(sContext);
+                transformBuilder.setEncryption(
+                        new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY));
+                transformBuilder.setAuthentication(
+                        new IpSecAlgorithm(
+                                IpSecAlgorithm.AUTH_HMAC_SHA256, AUTH_KEY, AUTH_KEY.length * 4));
+
+                if (useEncap) {
+                    transformBuilder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
+                }
+
+                // Apply transform and check that traffic is properly encrypted
+                try (IpSecManager.SecurityParameterIndex inSpi =
+                                mISM.allocateSecurityParameterIndex(localOuter, spi);
+                        IpSecManager.SecurityParameterIndex outSpi =
+                                mISM.allocateSecurityParameterIndex(remoteOuter, spi);
+                        IpSecTransform inTransform =
+                                transformBuilder.buildTunnelModeTransform(remoteOuter, inSpi);
+                        IpSecTransform outTransform =
+                                transformBuilder.buildTunnelModeTransform(localOuter, outSpi)) {
+                    mISM.applyTunnelModeTransform(
+                            tunnelIface, IpSecManager.DIRECTION_IN, inTransform);
+                    mISM.applyTunnelModeTransform(
+                            tunnelIface, IpSecManager.DIRECTION_OUT, outTransform);
+
+                    mTestRunnableFactory
+                            .getIpSecTunnelTestRunnable(
+                                    transportInTunnelMode,
+                                    spi,
+                                    localInner,
+                                    remoteInner,
+                                    localOuter,
+                                    remoteOuter,
+                                    inTransportTransform,
+                                    outTransportTransform,
+                                    useEncap ? encapSocket.getPort() : 0,
+                                    0,
+                                    expectedPacketSize)
+                            .run(ipsecNetwork, tunnelIface, tunUtils);
+                }
+            }
+        }
+    }
+
     private void checkTunnelOutput(
             int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode)
             throws Exception {
@@ -426,6 +650,28 @@
                 new InputPacketGeneratorIpSecTunnelTestRunnableFactory());
     }
 
+    /**
+     * Verifies outbound traffic through a tunnel before and after migrating it.
+     */
+    private void checkMigrateTunnelOutput(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode)
+            throws Exception {
+        final IpSecTunnelTestRunnableFactory outputFactory =
+                new MigrateIpSecTunnelTestRunnableFactory(true /* isOutputTest */);
+        checkTunnel(innerFamily, outerFamily, useEncap, transportInTunnelMode, outputFactory);
+    }
+
+    /**
+     * Verifies inbound traffic through a tunnel before and after migrating it.
+     */
+    private void checkMigrateTunnelInput(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode)
+            throws Exception {
+        final IpSecTunnelTestRunnableFactory inputFactory =
+                new MigrateIpSecTunnelTestRunnableFactory(false /* isOutputTest */);
+        checkTunnel(innerFamily, outerFamily, useEncap, transportInTunnelMode, inputFactory);
+    }
+
     /**
      * Validates that the kernel can talk to itself.
      *
@@ -579,7 +825,8 @@
                 IpSecManager.SecurityParameterIndex outSpi =
                         mISM.allocateSecurityParameterIndex(remoteOuter, spi);
                 IpSecManager.IpSecTunnelInterface tunnelIface =
-                        mISM.createIpSecTunnelInterface(localOuter, remoteOuter, sTunNetwork)) {
+                        mISM.createIpSecTunnelInterface(
+                                localOuter, remoteOuter, sTunWrapper.network)) {
             // Build the test network
             tunnelIface.addAddress(localInner, innerPrefixLen);
             testNetworkCb = mCtsNetUtils.setupAndGetTestNetwork(tunnelIface.getInterfaceName());
@@ -615,7 +862,7 @@
                 mISM.applyTunnelModeTransform(
                         tunnelIface, IpSecManager.DIRECTION_OUT, outTransform);
 
-                innerSocketPort = test.run(testNetwork);
+                innerSocketPort = test.run(testNetwork, tunnelIface, sTunWrapper.utils);
             }
 
             // Teardown the test network
@@ -739,6 +986,14 @@
         return maybeEncapPacket(srcOuter, dstOuter, encapPort, espPayload).getPacketBytes();
     }
 
+    /** Verifies traffic in both directions before and after migrating the tunnel. */
+    private void doTestMigrateTunnel(int innerFamily, int outerFamily, boolean useEncap,
+            boolean transportInTunnelMode) throws Exception {
+        assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
+        checkMigrateTunnelOutput(innerFamily, outerFamily, useEncap, transportInTunnelMode);
+        checkMigrateTunnelInput(innerFamily, outerFamily, useEncap, transportInTunnelMode);
+    }
+
     // Transport-in-Tunnel mode tests
     @Test
     public void testTransportInTunnelModeV4InV4() throws Exception {
@@ -747,6 +1002,12 @@
         checkTunnelInput(AF_INET, AF_INET, false, true);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTransportInTunnelModeV4InV4() throws Exception {
+        doTestMigrateTunnel(AF_INET, AF_INET, false, true);
+    }
+
     @Test
     public void testTransportInTunnelModeV4InV4Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -760,6 +1021,12 @@
         checkTunnelInput(AF_INET, AF_INET, true, true);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTransportInTunnelModeV4InV4UdpEncap() throws Exception {
+        doTestMigrateTunnel(AF_INET, AF_INET, true, true);
+    }
+
     @Test
     public void testTransportInTunnelModeV4InV4UdpEncapReflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -773,6 +1040,12 @@
         checkTunnelInput(AF_INET, AF_INET6, false, true);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTransportInTunnelModeV4InV6() throws Exception {
+        doTestMigrateTunnel(AF_INET, AF_INET6, false, true);
+    }
+
     @Test
     public void testTransportInTunnelModeV4InV6Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -786,6 +1059,12 @@
         checkTunnelInput(AF_INET6, AF_INET, false, true);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTransportInTunnelModeV6InV4() throws Exception {
+        doTestMigrateTunnel(AF_INET6, AF_INET, false, true);
+    }
+
     @Test
     public void testTransportInTunnelModeV6InV4Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -799,6 +1078,12 @@
         checkTunnelInput(AF_INET6, AF_INET, true, true);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTransportInTunnelModeV6InV4UdpEncap() throws Exception {
+        doTestMigrateTunnel(AF_INET6, AF_INET, true, true);
+    }
+
     @Test
     public void testTransportInTunnelModeV6InV4UdpEncapReflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -812,6 +1097,12 @@
         checkTunnelInput(AF_INET, AF_INET6, false, true);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTransportInTunnelModeV6InV6() throws Exception {
+        doTestMigrateTunnel(AF_INET6, AF_INET6, false, true);
+    }
+
     @Test
     public void testTransportInTunnelModeV6InV6Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -826,6 +1117,12 @@
         checkTunnelInput(AF_INET, AF_INET, false, false);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTunnelV4InV4() throws Exception {
+        doTestMigrateTunnel(AF_INET, AF_INET, false, false);
+    }
+
     @Test
     public void testTunnelV4InV4Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -839,6 +1136,12 @@
         checkTunnelInput(AF_INET, AF_INET, true, false);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTunnelV4InV4UdpEncap() throws Exception {
+        doTestMigrateTunnel(AF_INET, AF_INET, true, false);
+    }
+
     @Test
     public void testTunnelV4InV4UdpEncapReflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -852,6 +1155,12 @@
         checkTunnelInput(AF_INET, AF_INET6, false, false);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTunnelV4InV6() throws Exception {
+        doTestMigrateTunnel(AF_INET, AF_INET6, false, false);
+    }
+
     @Test
     public void testTunnelV4InV6Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -865,6 +1174,12 @@
         checkTunnelInput(AF_INET6, AF_INET, false, false);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTunnelV6InV4() throws Exception {
+        doTestMigrateTunnel(AF_INET6, AF_INET, false, false);
+    }
+
     @Test
     public void testTunnelV6InV4Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -878,6 +1193,12 @@
         checkTunnelInput(AF_INET6, AF_INET, true, false);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTunnelV6InV4UdpEncap() throws Exception {
+        doTestMigrateTunnel(AF_INET6, AF_INET, true, false);
+    }
+
     @Test
     public void testTunnelV6InV4UdpEncapReflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
@@ -891,6 +1212,12 @@
         checkTunnelInput(AF_INET6, AF_INET6, false, false);
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    @Test
+    public void testMigrateTunnelV6InV6() throws Exception {
+        doTestMigrateTunnel(AF_INET6, AF_INET6, false, false);
+    }
+
     @Test
     public void testTunnelV6InV6Reflected() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index c505cef..ccc9416 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -29,9 +29,9 @@
 import android.net.NattKeepalivePacketData
 import android.net.Network
 import android.net.NetworkAgent
+import android.net.NetworkAgentConfig
 import android.net.NetworkAgent.INVALID_NETWORK
 import android.net.NetworkAgent.VALID_NETWORK
-import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
 import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_CONGESTED
@@ -46,9 +46,17 @@
 import android.net.NetworkCapabilities.TRANSPORT_VPN
 import android.net.NetworkInfo
 import android.net.NetworkProvider
+import android.net.NetworkReleasedException
 import android.net.NetworkRequest
 import android.net.NetworkScore
 import android.net.RouteInfo
+import android.net.QosCallback
+import android.net.QosCallbackException
+import android.net.QosCallback.QosCallbackRegistrationException
+import android.net.QosFilter
+import android.net.QosSession
+import android.net.QosSessionAttributes
+import android.net.QosSocketInfo
 import android.net.SocketKeepalive
 import android.net.Uri
 import android.net.VpnManager
@@ -59,12 +67,17 @@
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkDestroyed
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkUnwanted
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnRegisterQosCallback
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnRemoveKeepalivePacketFilter
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnSaveAcceptUnvalidated
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnSignalStrengthThresholdsUpdated
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStartSocketKeepalive
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStopSocketKeepalive
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnUnregisterQosCallback
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnValidationStatus
+import android.net.cts.NetworkAgentTest.TestableQosCallback.CallbackEntry.OnError
+import android.net.cts.NetworkAgentTest.TestableQosCallback.CallbackEntry.OnQosSessionAvailable
+import android.net.cts.NetworkAgentTest.TestableQosCallback.CallbackEntry.OnQosSessionLost
 import android.os.Build
 import android.os.Handler
 import android.os.HandlerThread
@@ -72,6 +85,7 @@
 import android.os.Message
 import android.os.SystemClock
 import android.telephony.TelephonyManager
+import android.telephony.data.EpsBearerQosSessionAttributes
 import android.util.DebugUtils.valueToString
 import androidx.test.InstrumentationRegistry
 import com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity
@@ -97,9 +111,13 @@
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.timeout
 import org.mockito.Mockito.verify
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.net.Socket
 import java.time.Duration
 import java.util.Arrays
 import java.util.UUID
+import java.util.concurrent.Executors
 import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
@@ -143,7 +161,7 @@
     private val LOCAL_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1")
     private val REMOTE_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.2")
 
-    private val mCM = realContext.getSystemService(ConnectivityManager::class.java)
+    private val mCM = realContext.getSystemService(ConnectivityManager::class.java)!!
     private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")
     private val mFakeConnectivityService = FakeConnectivityService()
 
@@ -152,6 +170,7 @@
 
     private val agentsToCleanUp = mutableListOf<NetworkAgent>()
     private val callbacksToCleanUp = mutableListOf<TestableNetworkCallback>()
+    private var qosTestSocket: Socket? = null
 
     @Before
     fun setUp() {
@@ -163,6 +182,7 @@
     fun tearDown() {
         agentsToCleanUp.forEach { it.unregister() }
         callbacksToCleanUp.forEach { mCM.unregisterNetworkCallback(it) }
+        qosTestSocket?.close()
         mHandlerThread.quitSafely()
         instrumentation.getUiAutomation().dropShellPermissionIdentity()
     }
@@ -228,6 +248,11 @@
             data class OnSignalStrengthThresholdsUpdated(val thresholds: IntArray) : CallbackEntry()
             object OnNetworkCreated : CallbackEntry()
             object OnNetworkDestroyed : CallbackEntry()
+            data class OnRegisterQosCallback(
+                val callbackId: Int,
+                val filter: QosFilter
+            ) : CallbackEntry()
+            data class OnUnregisterQosCallback(val callbackId: Int) : CallbackEntry()
         }
 
         override fun onBandwidthUpdateRequested() {
@@ -276,6 +301,14 @@
             }
         }
 
+        override fun onQosCallbackRegistered(qosCallbackId: Int, filter: QosFilter) {
+            history.add(OnRegisterQosCallback(qosCallbackId, filter))
+        }
+
+        override fun onQosCallbackUnregistered(qosCallbackId: Int) {
+            history.add(OnUnregisterQosCallback(qosCallbackId))
+        }
+
         override fun onValidationStatus(status: Int, uri: Uri?) {
             history.add(OnValidationStatus(status, uri))
         }
@@ -307,6 +340,12 @@
             return foundCallback
         }
 
+        inline fun <reified T : CallbackEntry> expectCallback(valid: (T) -> Boolean) {
+            val foundCallback = history.poll(DEFAULT_TIMEOUT_MS)
+            assertTrue(foundCallback is T, "Expected ${T::class} but found $foundCallback")
+            assertTrue(valid(foundCallback), "Unexpected callback : $foundCallback")
+        }
+
         inline fun <reified T : CallbackEntry> eventuallyExpect() =
                 history.poll(DEFAULT_TIMEOUT_MS) { it is T }.also {
                     assertNotNull(it, "Callback ${T::class} not received")
@@ -390,7 +429,7 @@
         initialConfig: NetworkAgentConfig? = null,
         expectedInitSignalStrengthThresholds: IntArray? = intArrayOf()
     ): Pair<TestableNetworkAgent, TestableNetworkCallback> {
-        val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
+        val callback = TestableNetworkCallback()
         // Ensure this NetworkAgent is never unneeded by filing a request with its specifier.
         requestNetwork(makeTestNetworkRequest(specifier = specifier), callback)
         val agent = createNetworkAgent(context, specifier, initialConfig = initialConfig)
@@ -651,7 +690,7 @@
         assertFalse(vpnNc.hasCapability(NET_CAPABILITY_NOT_VPN))
         assertTrue(hasAllTransports(vpnNc, defaultNetworkTransports),
                 "VPN transports ${Arrays.toString(vpnNc.transportTypes)}" +
-                " lacking transports from ${Arrays.toString(defaultNetworkTransports)}")
+                        " lacking transports from ${Arrays.toString(defaultNetworkTransports)}")
 
         // Check that when no underlying networks are announced the underlying transport disappears.
         agent.setUnderlyingNetworks(listOf<Network>())
@@ -934,4 +973,251 @@
 
         // tearDown() will unregister the requests and agents
     }
+
+    private class TestableQosCallback : QosCallback() {
+        val history = ArrayTrackRecord<CallbackEntry>().newReadHead()
+
+        sealed class CallbackEntry {
+            data class OnQosSessionAvailable(val sess: QosSession, val attr: QosSessionAttributes)
+                : CallbackEntry()
+            data class OnQosSessionLost(val sess: QosSession)
+                : CallbackEntry()
+            data class OnError(val ex: QosCallbackException)
+                : CallbackEntry()
+        }
+
+        override fun onQosSessionAvailable(sess: QosSession, attr: QosSessionAttributes) {
+            history.add(OnQosSessionAvailable(sess, attr))
+        }
+
+        override fun onQosSessionLost(sess: QosSession) {
+            history.add(OnQosSessionLost(sess))
+        }
+
+        override fun onError(ex: QosCallbackException) {
+            history.add(OnError(ex))
+        }
+
+        // Polls the next recorded event (up to DEFAULT_TIMEOUT_MS) and asserts it is a T.
+        inline fun <reified T : CallbackEntry> expectCallback(): T {
+            val foundCallback = history.poll(DEFAULT_TIMEOUT_MS)
+            assertTrue(foundCallback is T, "Expected ${T::class} but found $foundCallback")
+            return foundCallback
+        }
+
+        // Like expectCallback<T>(), and additionally asserts that [valid] accepts the event.
+        inline fun <reified T : CallbackEntry> expectCallback(valid: (T) -> Boolean) {
+            val foundCallback = expectCallback<T>()
+            assertTrue(valid(foundCallback), "Unexpected callback : $foundCallback")
+        }
+
+        fun assertNoCallback() {
+            history.poll(NO_CALLBACK_TIMEOUT).let { assertNull(it, "Callback received: $it") }
+        }
+    }
+
+    private fun setupForQosCallbackTesting(): Pair<TestableNetworkAgent, Socket> {
+        // Returns a connected test agent and a socket bound to loopback on its network.
+        val request = NetworkRequest.Builder()
+                .clearCapabilities()
+                .addTransportType(TRANSPORT_TEST)
+                .build()
+        val callback = TestableNetworkCallback()
+        requestNetwork(request, callback)
+        val (agent, _) = createConnectedNetworkAgent()
+        val socket = assertNotNull(agent.network?.socketFactory?.createSocket()).also {
+            it.bind(InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
+        }
+        qosTestSocket = socket // Also kept in the field, presumably for cleanup in tearDown().
+        return Pair(agent, socket)
+    }
+
+    @Test
+    fun testQosCallbackRegisterWithUnregister() {
+        val (agent, socket) = setupForQosCallbackTesting()
+        val qosCallback = TestableQosCallback()
+        Executors.newSingleThreadExecutor().let { executor ->
+            try {
+                val info = QosSocketInfo(agent.network!!, socket)
+                mCM.registerQosCallback(info, executor, qosCallback)
+                val callbackId = agent.expectCallback<OnRegisterQosCallback>().callbackId
+                assertFailsWith<QosCallbackRegistrationException>(
+                        "The same callback cannot be " +
+                        "registered more than once without first being unregistered") {
+                    mCM.registerQosCallback(info, executor, qosCallback)
+                }
+
+                mCM.unregisterQosCallback(qosCallback)
+                agent.expectCallback<OnUnregisterQosCallback> { it.callbackId == callbackId }
+            } finally {
+                socket.close()
+                // Safety precaution in case the test failed before unregistering above.
+                mCM.unregisterQosCallback(qosCallback)
+                executor.shutdown()
+            }
+        }
+    }
+
+    @Test
+    fun testQosCallbackOnQosSession() {
+        val (agent, socket) = setupForQosCallbackTesting()
+        val qosCallback = TestableQosCallback()
+        Executors.newSingleThreadExecutor().let { executor ->
+            try {
+                val info = QosSocketInfo(agent.network!!, socket)
+                mCM.registerQosCallback(info, executor, qosCallback)
+                val callbackId = agent.expectCallback<OnRegisterQosCallback>().callbackId
+
+                val uniqueSessionId = 4294967397 // (1L shl 32) or sessId
+                val sessId = 101
+
+                val attributes = createEpsAttributes(5)
+                assertEquals(5, attributes.qosIdentifier)
+                agent.sendQosSessionAvailable(callbackId, sessId, attributes)
+                qosCallback.expectCallback<OnQosSessionAvailable> {
+                    it.sess.sessionId == sessId && it.sess.uniqueId == uniqueSessionId &&
+                            it.sess.sessionType == QosSession.TYPE_EPS_BEARER
+                }
+
+                agent.sendQosSessionLost(callbackId, sessId, QosSession.TYPE_EPS_BEARER)
+                qosCallback.expectCallback<OnQosSessionLost> {
+                    it.sess.sessionId == sessId && it.sess.uniqueId == uniqueSessionId &&
+                            it.sess.sessionType == QosSession.TYPE_EPS_BEARER
+                }
+
+                // Make sure that we don't get more qos callbacks
+                mCM.unregisterQosCallback(qosCallback)
+                agent.expectCallback<OnUnregisterQosCallback> { it.callbackId == callbackId }
+
+                agent.sendQosSessionLost(callbackId, sessId, QosSession.TYPE_EPS_BEARER)
+                qosCallback.assertNoCallback()
+            } finally {
+                socket.close()
+
+                // Safety precaution in case the test failed before unregistering above.
+                mCM.unregisterQosCallback(qosCallback)
+
+                executor.shutdown()
+            }
+        }
+    }
+
+    @Test
+    fun testQosCallbackOnError() {
+        val (agent, socket) = setupForQosCallbackTesting()
+        val qosCallback = TestableQosCallback()
+        Executors.newSingleThreadExecutor().let { executor ->
+            try {
+                val info = QosSocketInfo(agent.network!!, socket)
+                mCM.registerQosCallback(info, executor, qosCallback)
+                val callbackId = agent.expectCallback<OnRegisterQosCallback>().callbackId
+
+                val sessId = 101
+                val attributes = createEpsAttributes()
+
+                // Sanity check: session events reach the callback before any error is sent
+                agent.sendQosSessionAvailable(callbackId, sessId, attributes)
+                qosCallback.expectCallback<OnQosSessionAvailable>()
+
+                // FILTER_NOT_SUPPORTED must surface in onError as UnsupportedOperationException
+                agent.sendQosCallbackError(callbackId,
+                        QosCallbackException.EX_TYPE_FILTER_NOT_SUPPORTED)
+                qosCallback.expectCallback<OnError> {
+                    it.ex.cause is UnsupportedOperationException
+                }
+
+                // After an error the callback is expected to be unregistered: no more events
+                agent.sendQosSessionLost(callbackId, sessId, QosSession.TYPE_EPS_BEARER)
+                qosCallback.assertNoCallback()
+            } finally {
+                socket.close()
+
+                // Make sure that the callback is fully unregistered
+                mCM.unregisterQosCallback(qosCallback)
+
+                executor.shutdown()
+            }
+        }
+    }
+
+    @Test
+    fun testQosCallbackIdsAreMappedCorrectly() {
+        val (agent, socket) = setupForQosCallbackTesting()
+        val qosCallback1 = TestableQosCallback()
+        val qosCallback2 = TestableQosCallback()
+        Executors.newSingleThreadExecutor().let { executor ->
+            try {
+                val info = QosSocketInfo(agent.network!!, socket)
+                mCM.registerQosCallback(info, executor, qosCallback1)
+                val callbackId1 = agent.expectCallback<OnRegisterQosCallback>().callbackId
+
+                mCM.registerQosCallback(info, executor, qosCallback2)
+                val callbackId2 = agent.expectCallback<OnRegisterQosCallback>().callbackId
+
+                val sessId1 = 101
+                val attributes1 = createEpsAttributes(1)
+
+                // Check #1: a session on callbackId1 reaches qosCallback1 only
+                agent.sendQosSessionAvailable(callbackId1, sessId1, attributes1)
+                qosCallback1.expectCallback<OnQosSessionAvailable> { sessId1 == it.sess.sessionId }
+                qosCallback2.assertNoCallback()
+
+                // Check #2: a session on callbackId2 reaches qosCallback2 only
+                val sessId2 = 102
+                val attributes2 = createEpsAttributes(2)
+                agent.sendQosSessionAvailable(callbackId2, sessId2, attributes2)
+                qosCallback1.assertNoCallback()
+                qosCallback2.expectCallback<OnQosSessionAvailable> { sessId2 == it.sess.sessionId }
+            } finally {
+                socket.close()
+
+                // Make sure that the callback is fully unregistered
+                mCM.unregisterQosCallback(qosCallback1)
+                mCM.unregisterQosCallback(qosCallback2)
+
+                executor.shutdown()
+            }
+        }
+    }
+
+    @Test
+    fun testQosCallbackWhenNetworkReleased() {
+        val (agent, socket) = setupForQosCallbackTesting()
+        val qosCallback1 = TestableQosCallback()
+        val qosCallback2 = TestableQosCallback()
+        Executors.newSingleThreadExecutor().let { executor ->
+            try {
+                val info = QosSocketInfo(agent.network!!, socket)
+                mCM.registerQosCallback(info, executor, qosCallback1)
+                mCM.registerQosCallback(info, executor, qosCallback2)
+                agent.unregister()
+
+                // Releasing the network must report NetworkReleasedException to every callback.
+                qosCallback1.expectCallback<OnError> {
+                    it.ex.cause is NetworkReleasedException
+                }
+
+                qosCallback2.expectCallback<OnError> {
+                    it.ex.cause is NetworkReleasedException
+                }
+            } finally {
+                socket.close()
+
+                // Make sure that the callbacks are fully unregistered
+                mCM.unregisterQosCallback(qosCallback1)
+                mCM.unregisterQosCallback(qosCallback2)
+
+                executor.shutdown()
+            }
+        }
+    }
+
+    private fun createEpsAttributes(qci: Int = 1): EpsBearerQosSessionAttributes {
+        // One remote address from the IPv6 documentation prefix (2001:db8::/32).
+        val remoteAddresses = arrayListOf(InetSocketAddress("2001:db8::123", 80))
+        return EpsBearerQosSessionAttributes(
+                qci, 2, 3, 4, 5,
+                remoteAddresses
+        )
+    }
 }
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index beae0cf..ba54273 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -60,7 +60,6 @@
         "java/**/*.kt",
     ],
     test_suites: ["device-tests"],
-    certificate: "platform",
     jarjar_rules: "jarjar-rules.txt",
     static_libs: [
         "androidx.test.rules",
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index be7239d..f6ea964 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -18,10 +18,15 @@
 
 import static android.Manifest.permission.CHANGE_NETWORK_STATE;
 import static android.Manifest.permission.CONNECTIVITY_USE_RESTRICTED_NETWORKS;
+import static android.Manifest.permission.CONTROL_OEM_PAID_NETWORK_PREFERENCE;
+import static android.Manifest.permission.CREATE_USERS;
 import static android.Manifest.permission.DUMP;
+import static android.Manifest.permission.GET_INTENT_SENDER_INTENT;
 import static android.Manifest.permission.LOCAL_MAC_ADDRESS;
 import static android.Manifest.permission.NETWORK_FACTORY;
 import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.Manifest.permission.NETWORK_STACK;
+import static android.Manifest.permission.PACKET_KEEPALIVE_OFFLOAD;
 import static android.app.PendingIntent.FLAG_IMMUTABLE;
 import static android.content.Intent.ACTION_PACKAGE_ADDED;
 import static android.content.Intent.ACTION_PACKAGE_REMOVED;
@@ -134,6 +139,7 @@
 import static com.android.testutils.MiscAsserts.assertRunsInAtMost;
 import static com.android.testutils.MiscAsserts.assertSameElements;
 import static com.android.testutils.MiscAsserts.assertThrows;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -259,6 +265,7 @@
 import android.net.shared.PrivateDnsConfig;
 import android.net.util.MultinetworkPolicyTracker;
 import android.os.BadParcelableException;
+import android.os.BatteryStatsManager;
 import android.os.Binder;
 import android.os.Build;
 import android.os.Bundle;
@@ -297,6 +304,7 @@
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.connectivity.resources.R;
+import com.android.internal.app.IBatteryStats;
 import com.android.internal.net.VpnConfig;
 import com.android.internal.net.VpnProfile;
 import com.android.internal.util.ArrayUtils;
@@ -305,6 +313,7 @@
 import com.android.internal.util.test.FakeSettingsProvider;
 import com.android.net.module.util.ArrayTrackRecord;
 import com.android.net.module.util.CollectionUtils;
+import com.android.net.module.util.LocationPermissionChecker;
 import com.android.server.ConnectivityService.ConnectivityDiagnosticsCallbackInfo;
 import com.android.server.ConnectivityService.NetworkRequestInfo;
 import com.android.server.connectivity.MockableSystemProperties;
@@ -488,6 +497,11 @@
     @Mock Resources mResources;
     @Mock ProxyTracker mProxyTracker;
 
+    // BatteryStatsManager is final and cannot be mocked with regular mockito, so wrap a mocked
+    // IBatteryStats binder instead; getSystemService(BATTERY_STATS_SERVICE) returns this.
+    final BatteryStatsManager mBatteryStatsManager =
+            new BatteryStatsManager(mock(IBatteryStats.class));
+
     private ArgumentCaptor<ResolverParamsParcel> mResolverParamsParcelCaptor =
             ArgumentCaptor.forClass(ResolverParamsParcel.class);
 
@@ -579,6 +593,7 @@
             if (Context.NETWORK_POLICY_SERVICE.equals(name)) return mNetworkPolicyManager;
             if (Context.SYSTEM_CONFIG_SERVICE.equals(name)) return mSystemConfigManager;
             if (Context.NETWORK_STATS_SERVICE.equals(name)) return mStatsManager;
+            if (Context.BATTERY_STATS_SERVICE.equals(name)) return mBatteryStatsManager;
             return super.getSystemService(name);
         }
 
@@ -659,6 +674,15 @@
         public void setPermission(String permission, Integer granted) {
             mMockedPermissions.put(permission, granted);
         }
+
+        @Override
+        public Intent registerReceiverForAllUsers(@Nullable BroadcastReceiver receiver,
+                @NonNull IntentFilter filter, @Nullable String broadcastPermission,
+                @Nullable Handler scheduler) {
+            // The receiver is never registered, so it never fires; null means no sticky intent.
+            // TODO: test MultinetworkPolicyTracker's receiver; returning null should not pass.
+            return null;
+        }
     }
 
     private void waitForIdle() {
@@ -1208,7 +1232,24 @@
                             return mDeviceIdleInternal;
                         }
                     },
-                    mNetworkManagementService, mMockNetd, userId, mVpnProfileStore);
+                    mNetworkManagementService, mMockNetd, userId, mVpnProfileStore,
+                    new SystemServices(mServiceContext) {
+                        @Override
+                        public String settingsSecureGetStringForUser(String key, int userId) {
+                            switch (key) {
+                                // Keys not marked @Readable (e.g. ALWAYS_ON_VPN_APP) can't be read
+                                // by non-privileged apps, even if mocked in the content provider,
+                                // because Settings.Secure.NameValueCache#getStringForUser checks
+                                // the key before querying the mock settings provider. Declaring
+                                // testOnly=true would lift this, but atest refuses to install
+                                // testOnly apps; so ALWAYS_ON_VPN_APP is reported unset (null).
+                                case Settings.Secure.ALWAYS_ON_VPN_APP:
+                                    return null;
+                                default:
+                                    return super.settingsSecureGetStringForUser(key, userId);
+                            }
+                        }
+                    }, new Ikev2SessionCreator());
         }
 
         public void setUids(Set<UidRange> uids) {
@@ -1455,8 +1496,7 @@
         return mService.getNetworkAgentInfoForNetwork(mna.getNetwork()).clatd;
     }
 
-    private static class WrappedMultinetworkPolicyTracker extends MultinetworkPolicyTracker {
-        volatile boolean mConfigRestrictsAvoidBadWifi;
+    private class WrappedMultinetworkPolicyTracker extends MultinetworkPolicyTracker {
         volatile int mConfigMeteredMultipathPreference;
 
         WrappedMultinetworkPolicyTracker(Context c, Handler h, Runnable r) {
@@ -1464,8 +1504,8 @@
         }
 
         @Override
-        public boolean configRestrictsAvoidBadWifi() {
-            return mConfigRestrictsAvoidBadWifi;
+        protected Resources getResourcesForActiveSubId() {
+            return mResources;
         }
 
         @Override
@@ -1592,6 +1632,11 @@
         mServiceContext = new MockContext(InstrumentationRegistry.getContext(),
                 new FakeSettingsProvider());
         mServiceContext.setUseRegisteredHandlers(true);
+        mServiceContext.setPermission(NETWORK_FACTORY, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(CONTROL_OEM_PAID_NETWORK_PREFERENCE, PERMISSION_GRANTED);
+        mServiceContext.setPermission(PACKET_KEEPALIVE_OFFLOAD, PERMISSION_GRANTED);
+        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
 
         mAlarmManagerThread = new HandlerThread("TestAlarmManager");
         mAlarmManagerThread.start();
@@ -1651,6 +1696,13 @@
             return mPolicyTracker;
         }).when(deps).makeMultinetworkPolicyTracker(any(), any(), any());
         doReturn(true).when(deps).getCellular464XlatEnabled();
+        doAnswer(inv ->
+            new LocationPermissionChecker(inv.getArgument(0)) {
+                @Override
+                protected int getCurrentUser() {
+                    return runAsShell(CREATE_USERS, () -> super.getCurrentUser());
+                }
+            }).when(deps).makeLocationPermissionChecker(any());
 
         doReturn(60000).when(mResources).getInteger(R.integer.config_networkTransitionTimeout);
         doReturn("").when(mResources).getString(R.string.config_networkCaptivePortalServerUrl);
@@ -1670,7 +1722,9 @@
                 .getIdentifier(eq("config_networkSupportedKeepaliveCount"), eq("array"), any());
         doReturn(R.array.network_switch_type_name).when(mResources)
                 .getIdentifier(eq("network_switch_type_name"), eq("array"), any());
-
+        doReturn(R.integer.config_networkAvoidBadWifi).when(mResources)
+                .getIdentifier(eq("config_networkAvoidBadWifi"), eq("integer"), any());
+        doReturn(1).when(mResources).getInteger(R.integer.config_networkAvoidBadWifi);
 
         final ConnectivityResources connRes = mock(ConnectivityResources.class);
         doReturn(mResources).when(connRes).get();
@@ -1680,6 +1734,12 @@
         doReturn(mResources).when(mockResContext).getResources();
         ConnectivityResources.setResourcesContextForTest(mockResContext);
 
+        doAnswer(inv -> {
+            final PendingIntent a = inv.getArgument(0);
+            final PendingIntent b = inv.getArgument(1);
+            return runAsShell(GET_INTENT_SENDER_INTENT, () -> a.intentFilterEquals(b));
+        }).when(deps).intentFilterEquals(any(), any());
+
         return deps;
     }
 
@@ -4704,30 +4764,29 @@
     }
 
     @Test
-    public void testAvoidBadWifiSetting() throws Exception {
+    public void testSetAllowBadWifiUntil() throws Exception {
+        runAsShell(NETWORK_SETTINGS,
+                () -> mService.setTestAllowBadWifiUntil(System.currentTimeMillis() + 5_000L));
+        waitForIdle();
+        testAvoidBadWifiConfig_controlledBySettings();
+        // Once the override has expired, the config (1 in setup) wins and the setting is ignored.
+        runAsShell(NETWORK_SETTINGS,
+                () -> mService.setTestAllowBadWifiUntil(System.currentTimeMillis() - 5_000L));
+        waitForIdle();
+        testAvoidBadWifiConfig_ignoreSettings();
+    }
+
+    private void testAvoidBadWifiConfig_controlledBySettings() {
         final ContentResolver cr = mServiceContext.getContentResolver();
         final String settingName = ConnectivitySettingsManager.NETWORK_AVOID_BAD_WIFI;
 
-        mPolicyTracker.mConfigRestrictsAvoidBadWifi = false;
-        String[] values = new String[] {null, "0", "1"};
-        for (int i = 0; i < values.length; i++) {
-            Settings.Global.putInt(cr, settingName, 1);
-            mPolicyTracker.reevaluate();
-            waitForIdle();
-            String msg = String.format("config=false, setting=%s", values[i]);
-            assertTrue(mService.avoidBadWifi());
-            assertFalse(msg, mPolicyTracker.shouldNotifyWifiUnvalidated());
-        }
-
-        mPolicyTracker.mConfigRestrictsAvoidBadWifi = true;
-
-        Settings.Global.putInt(cr, settingName, 0);
+        Settings.Global.putString(cr, settingName, "0");
         mPolicyTracker.reevaluate();
         waitForIdle();
         assertFalse(mService.avoidBadWifi());
         assertFalse(mPolicyTracker.shouldNotifyWifiUnvalidated());
 
-        Settings.Global.putInt(cr, settingName, 1);
+        Settings.Global.putString(cr, settingName, "1");
         mPolicyTracker.reevaluate();
         waitForIdle();
         assertTrue(mService.avoidBadWifi());
@@ -4740,13 +4799,40 @@
         assertTrue(mPolicyTracker.shouldNotifyWifiUnvalidated());
     }
 
+    private void testAvoidBadWifiConfig_ignoreSettings() {
+        final ContentResolver cr = mServiceContext.getContentResolver();
+        final String settingName = ConnectivitySettingsManager.NETWORK_AVOID_BAD_WIFI;
+        // Expects avoidBadWifi() to be true whatever the setting is (callers use config=1).
+        final String[] values = new String[] {null, "0", "1"};
+        for (String value : values) {
+            Settings.Global.putString(cr, settingName, value);
+            mPolicyTracker.reevaluate();
+            waitForIdle();
+            final String msg = String.format("config=1, setting=%s", value);
+            assertTrue(msg, mService.avoidBadWifi());
+            assertFalse(msg, mPolicyTracker.shouldNotifyWifiUnvalidated());
+        }
+    }
+
+    @Test
+    public void testAvoidBadWifiSetting() throws Exception {
+        // config_networkAvoidBadWifi=1: bad wifi is avoided and the setting is ignored.
+        // config_networkAvoidBadWifi=0: the NETWORK_AVOID_BAD_WIFI setting decides.
+
+        doReturn(1).when(mResources).getInteger(R.integer.config_networkAvoidBadWifi);
+        testAvoidBadWifiConfig_ignoreSettings();
+
+        doReturn(0).when(mResources).getInteger(R.integer.config_networkAvoidBadWifi);
+        testAvoidBadWifiConfig_controlledBySettings();
+    }
+
     @Ignore("Refactoring in progress b/178071397")
     @Test
     public void testAvoidBadWifi() throws Exception {
         final ContentResolver cr = mServiceContext.getContentResolver();
 
         // Pretend we're on a carrier that restricts switching away from bad wifi.
-        mPolicyTracker.mConfigRestrictsAvoidBadWifi = true;
+        doReturn(0).when(mResources).getInteger(R.integer.config_networkAvoidBadWifi);
 
         // File a request for cell to ensure it doesn't go down.
         final TestNetworkCallback cellNetworkCallback = new TestNetworkCallback();
@@ -4797,13 +4883,13 @@
 
         // Simulate switching to a carrier that does not restrict avoiding bad wifi, and expect
         // that we switch back to cell.
-        mPolicyTracker.mConfigRestrictsAvoidBadWifi = false;
+        doReturn(1).when(mResources).getInteger(R.integer.config_networkAvoidBadWifi);
         mPolicyTracker.reevaluate();
         defaultCallback.expectAvailableCallbacksValidated(mCellNetworkAgent);
         assertEquals(mCm.getActiveNetwork(), cellNetwork);
 
         // Switch back to a restrictive carrier.
-        mPolicyTracker.mConfigRestrictsAvoidBadWifi = true;
+        doReturn(0).when(mResources).getInteger(R.integer.config_networkAvoidBadWifi);
         mPolicyTracker.reevaluate();
         defaultCallback.expectAvailableCallbacksUnvalidated(mWiFiNetworkAgent);
         assertEquals(mCm.getActiveNetwork(), wifiNetwork);
@@ -9357,8 +9443,7 @@
         mServiceContext.setPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
                 PERMISSION_DENIED);
         mServiceContext.setPermission(NETWORK_SETTINGS, PERMISSION_DENIED);
-        mServiceContext.setPermission(Manifest.permission.NETWORK_STACK,
-                PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
         mServiceContext.setPermission(Manifest.permission.NETWORK_SETUP_WIZARD,
                 PERMISSION_DENIED);
     }
@@ -9799,7 +9884,7 @@
         setupConnectionOwnerUid(vpnOwnerUid, vpnType);
 
         // Test as VPN app
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
         mServiceContext.setPermission(
                 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, PERMISSION_DENIED);
     }
@@ -9839,8 +9924,7 @@
     public void testGetConnectionOwnerUidVpnServiceNetworkStackDoesNotThrow() throws Exception {
         final int myUid = Process.myUid();
         setupConnectionOwnerUid(myUid, VpnManager.TYPE_VPN_SERVICE);
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
 
         assertEquals(42, mService.getConnectionOwnerUid(getTestConnectionInfo()));
     }
@@ -10008,8 +10092,7 @@
     public void testCheckConnectivityDiagnosticsPermissionsNetworkStack() throws Exception {
         final NetworkAgentInfo naiWithoutUid = fakeMobileNai(new NetworkCapabilities());
 
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
         assertTrue(
                 "NetworkStack permission not applied",
                 mService.checkConnectivityDiagnosticsPermissions(
@@ -10025,7 +10108,7 @@
         nc.setAdministratorUids(new int[] {wrongUid});
         final NetworkAgentInfo naiWithUid = fakeWifiNai(nc);
 
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         assertFalse(
                 "Mismatched uid/package name should not pass the location permission check",
@@ -10035,7 +10118,7 @@
 
     private void verifyConnectivityDiagnosticsPermissionsWithNetworkAgentInfo(
             NetworkAgentInfo info, boolean expectPermission) {
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         assertEquals(
                 "Unexpected ConnDiags permission",
@@ -10103,7 +10186,7 @@
 
         setupLocationPermissions(Build.VERSION_CODES.Q, true, AppOpsManager.OPSTR_FINE_LOCATION,
                 Manifest.permission.ACCESS_FINE_LOCATION);
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         assertTrue(
                 "NetworkCapabilities administrator uid permission not applied",
@@ -10120,7 +10203,7 @@
 
         setupLocationPermissions(Build.VERSION_CODES.Q, true, AppOpsManager.OPSTR_FINE_LOCATION,
                 Manifest.permission.ACCESS_FINE_LOCATION);
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         // Use wrong pid and uid
         assertFalse(
@@ -10146,8 +10229,7 @@
         final NetworkRequest request = new NetworkRequest.Builder().build();
         when(mConnectivityDiagnosticsCallback.asBinder()).thenReturn(mIBinder);
 
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
 
         mService.registerConnectivityDiagnosticsCallback(
                 mConnectivityDiagnosticsCallback, request, mContext.getPackageName());
@@ -10166,8 +10248,7 @@
         final NetworkRequest request = new NetworkRequest.Builder().build();
         when(mConnectivityDiagnosticsCallback.asBinder()).thenReturn(mIBinder);
 
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
 
         mService.registerConnectivityDiagnosticsCallback(
                 mConnectivityDiagnosticsCallback, request, mContext.getPackageName());
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index c32c1d2..3566aef 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -16,12 +16,17 @@
 
 package com.android.server.net;
 
+import static android.Manifest.permission.READ_NETWORK_USAGE_HISTORY;
+import static android.Manifest.permission.UPDATE_DEVICE_STATS;
 import static android.content.Intent.ACTION_UID_REMOVED;
 import static android.content.Intent.EXTRA_UID;
+import static android.content.pm.PackageManager.PERMISSION_DENIED;
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
 import static android.net.ConnectivityManager.TYPE_WIFI;
 import static android.net.NetworkIdentity.OEM_PAID;
 import static android.net.NetworkIdentity.OEM_PRIVATE;
+import static android.net.NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK;
 import static android.net.NetworkStats.DEFAULT_NETWORK_ALL;
 import static android.net.NetworkStats.DEFAULT_NETWORK_NO;
 import static android.net.NetworkStats.DEFAULT_NETWORK_YES;
@@ -106,6 +111,7 @@
 import android.provider.Settings;
 import android.telephony.TelephonyManager;
 
+import androidx.annotation.Nullable;
 import androidx.test.InstrumentationRegistry;
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -200,6 +206,26 @@
             if (Context.TELEPHONY_SERVICE.equals(name)) return mTelephonyManager;
             return mBaseContext.getSystemService(name);
         }
+
+        @Override
+        public void enforceCallingOrSelfPermission(String permission, @Nullable String message) {
+            if (checkCallingOrSelfPermission(permission) == PERMISSION_GRANTED) return;
+            // Unlike the real Context, the caller-supplied message is not included.
+            throw new SecurityException("Test does not have mocked permission " + permission);
+        }
+
+        @Override
+        public int checkCallingOrSelfPermission(String permission) {
+            // Grant only a fixed allow-list of mocked permissions; deny everything else.
+            final boolean granted =
+                    permission.equals(PERMISSION_MAINLINE_NETWORK_STACK)
+                    || permission.equals(READ_NETWORK_USAGE_HISTORY)
+                    || permission.equals(UPDATE_DEVICE_STATS);
+            if (granted) {
+                return PERMISSION_GRANTED;
+            }
+            return PERMISSION_DENIED;
+        }
     }
 
     private final Clock mClock = new SimpleClock(ZoneOffset.UTC) {