Inform ConnectivityService about always-on VPN lockdown.

Currently, when an always-on VPN is set in lockdown mode, Vpn
configures prohibit UID rules in netd directly and does not
inform ConnectivityService of the fact.

This means that ConnectivityService cannot send NetworkCallbacks
that tells apps that they are blocked or unblocked. It also means
that ConnectivityService has to take the mVpns lock and call into
Vpn to allow synchronous APIs such as getActiveNetwork to return
BLOCKED if the app is blocked.

Move all this to ConnectivityService:
- Add a setRequireVpnForUids API to ConnectivityManager, and have
  that pass the routing rules to netd.
- Update VpnTest to expect calls to ConnectivityManager instead
  of to netd.
- Whenever setRequireVpnForUids is called, ensure that
  ConnectivityService sends onBlockedStatusChanged to the
  affected callbacks.
- Update existing unit tests to check for callbacks.
- Add a way to find the VPN that applies to a given UID without
  taking the VPN lock, by instead scanning all connected VPNs.
  Use this as a replacement for direct access to mVpns.

For simplicity, and in order to ensure proper ordering between
the NetworkCallbacks sent for VPNs connecting and disconnecting,
process blocked UID ranges on the handler thread. This means that
when setRequireVpnForUids returns, the rule changes might not
have been applied. This shouldn't impact apps using network
connectivity, but it might mean that apps setting an always-on
package, and then immediately checking whether networking is
blocked, will see a behaviour change.

Bug: 173331190
Fix: 175670887
Test: new test coverage in ConnectivityServiceTest
Test: atest MixedDeviceOwnerTest#testAlwaysOnVpn \
            MixedDeviceOwnerTest#testAlwaysOnVpnLockDown \
	    MixedDeviceOwnerTest#testAlwaysOnVpnAcrossReboot \
	    MixedDeviceOwnerTest#testAlwaysOnVpnPackageUninstalled \
	    MixedDeviceOwnerTest#testAlwaysOnVpnUnsupportedPackage \
	    MixedDeviceOwnerTest#testAlwaysOnVpnUnsupportedPackageReplaced \
	    MixedDeviceOwnerTest#testAlwaysOnVpnPackageLogged \
            MixedProfileOwnerTest#testAlwaysOnVpn \
            MixedProfileOwnerTest#testAlwaysOnVpnLockDown \
	    MixedProfileOwnerTest#testAlwaysOnVpnAcrossReboot \
	    MixedProfileOwnerTest#testAlwaysOnVpnPackageUninstalled \
	    MixedProfileOwnerTest#testAlwaysOnVpnUnsupportedPackage \
	    MixedProfileOwnerTest#testAlwaysOnVpnUnsupportedPackageReplaced \
	    MixedProfileOwnerTest#testAlwaysOnVpnPackageLogged \
            MixedManagedProfileOwnerTest#testAlwaysOnVpn \
            MixedManagedProfileOwnerTest#testAlwaysOnVpnLockDown \
	    MixedManagedProfileOwnerTest#testAlwaysOnVpnAcrossReboot \
	    MixedManagedProfileOwnerTest#testAlwaysOnVpnPackageUninstalled \
	    MixedManagedProfileOwnerTest#testAlwaysOnVpnUnsupportedPackage \
	    MixedManagedProfileOwnerTest#testAlwaysOnVpnUnsupportedPackageReplaced \
	    MixedManagedProfileOwnerTest#testAlwaysOnVpnPackageLogged
Test: atest FrameworksNetTests HostsideVpnTests \
            CtsNetTestCases:VpnServiceTest \
	    CtsNetTestCases:Ikev2VpnTest
Change-Id: Iaca8a7cc343aef52706cff62a7735f338cb1b772
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index 0d10e4a..06c1598 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -59,6 +59,7 @@
 import android.telephony.TelephonyManager;
 import android.util.ArrayMap;
 import android.util.Log;
+import android.util.Range;
 import android.util.SparseIntArray;
 
 import com.android.connectivity.aidl.INetworkAgent;
@@ -73,10 +74,12 @@
 import java.io.UncheckedIOException;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.net.DatagramSocket;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.Socket;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -1163,6 +1166,55 @@
     }
 
     /**
+     * Adds or removes a requirement for given UID ranges to use the VPN.
+     *
+     * If set to {@code true}, informs the system that the UIDs in the specified ranges must not
+     * have any connectivity except if a VPN is connected and applies to the UIDs, or if the UIDs
+     * otherwise have permission to bypass the VPN (e.g., because they have the
+     * {@link android.Manifest.permission.CONNECTIVITY_USE_RESTRICTED_NETWORKS} permission, or when
+     * using a socket protected by a method such as {@link VpnService#protect(DatagramSocket)}. If
+     * set to {@code false}, a previously-added restriction is removed.
+     * <p>
+     * Each of the UID ranges specified by this method is added and removed as is, and no processing
+     * is performed on the ranges to de-duplicate, merge, split, or intersect them. In order to
+     * remove a previously-added range, the exact range must be removed as is.
+     * <p>
+     * The changes are applied asynchronously and may not have been applied by the time the method
+     * returns. Apps will be notified about any changes that apply to them via
+     * {@link NetworkCallback#onBlockedStatusChanged} callbacks called after the changes take
+     * effect.
+     * <p>
+     * This method should be called only by the VPN code.
+     *
+     * @param ranges the UID ranges to restrict
+     * @param requireVpn whether the specified UID ranges must use a VPN
+     *
+     * TODO: expose as @SystemApi.
+     * @hide
+     */
+    @RequiresPermission(anyOf = {
+            NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
+            android.Manifest.permission.NETWORK_STACK})
+    public void setRequireVpnForUids(boolean requireVpn,
+            @NonNull Collection<Range<Integer>> ranges) {
+        Objects.requireNonNull(ranges);
+        // The Range class is not parcelable. Convert to UidRange, which is what is used internally.
+        // This method is not necessarily expected to be used outside the system server, so
+        // parceling may not be necessary, but it could be used out-of-process, e.g., by the network
+        // stack process, or by tests.
+        UidRange[] rangesArray = new UidRange[ranges.size()];
+        int index = 0;
+        for (Range<Integer> range : ranges) {
+            rangesArray[index++] = new UidRange(range.getLower(), range.getUpper());
+        }
+        try {
+            mService.setRequireVpnForUids(requireVpn, rangesArray);
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
+    /**
      * Returns details about the currently active default data network
      * for a given uid.  This is for internal use only to avoid spying
      * other apps.
diff --git a/core/java/android/net/IConnectivityManager.aidl b/core/java/android/net/IConnectivityManager.aidl
index b5dd947..b32c98b 100644
--- a/core/java/android/net/IConnectivityManager.aidl
+++ b/core/java/android/net/IConnectivityManager.aidl
@@ -29,6 +29,7 @@
 import android.net.NetworkState;
 import android.net.ISocketKeepaliveCallback;
 import android.net.ProxyInfo;
+import android.net.UidRange;
 import android.os.Bundle;
 import android.os.IBinder;
 import android.os.INetworkActivityListener;
@@ -146,6 +147,7 @@
     String getAlwaysOnVpnPackage(int userId);
     boolean isVpnLockdownEnabled(int userId);
     List<String> getVpnLockdownWhitelist(int userId);
+    void setRequireVpnForUids(boolean requireVpn, in UidRange[] ranges);
 
     void setProvisioningNotificationVisible(boolean visible, int networkType, in String action);
 
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 164ce61..0f677c4 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -547,6 +547,13 @@
     private static final int EVENT_CAPPORT_DATA_CHANGED = 46;
 
     /**
+     * Used by setRequireVpnForUids.
+     * arg1 = whether the specified UID ranges are required to use a VPN.
+     * obj  = Array of UidRange objects.
+     */
+    private static final int EVENT_SET_REQUIRE_VPN_FOR_UIDS = 47;
+
+    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -1274,19 +1281,28 @@
         }
     }
 
-    private Network[] getVpnUnderlyingNetworks(int uid) {
-        synchronized (mVpns) {
-            if (!mLockdownEnabled) {
-                int user = UserHandle.getUserId(uid);
-                Vpn vpn = mVpns.get(user);
-                if (vpn != null && vpn.appliesToUid(uid)) {
-                    return vpn.getUnderlyingNetworks();
+    // TODO: determine what to do when more than one VPN applies to |uid|.
+    private NetworkAgentInfo getVpnForUid(int uid) {
+        synchronized (mNetworkForNetId) {
+            for (int i = 0; i < mNetworkForNetId.size(); i++) {
+                final NetworkAgentInfo nai = mNetworkForNetId.valueAt(i);
+                if (nai.isVPN() && nai.everConnected && nai.networkCapabilities.appliesToUid(uid)) {
+                    return nai;
                 }
             }
         }
         return null;
     }
 
+    private Network[] getVpnUnderlyingNetworks(int uid) {
+        synchronized (mVpns) {
+            if (mLockdownEnabled) return null;
+        }
+        final NetworkAgentInfo nai = getVpnForUid(uid);
+        if (nai != null) return nai.declaredUnderlyingNetworks;
+        return null;
+    }
+
     private NetworkState getUnfilteredActiveNetworkState(int uid) {
         NetworkAgentInfo nai = getDefaultNetwork();
 
@@ -1312,7 +1328,7 @@
     }
 
     /**
-     * Check if UID should be blocked from using the network with the given LinkProperties.
+     * Check if UID should be blocked from using the specified network.
      */
     private boolean isNetworkWithLinkPropertiesBlocked(LinkProperties lp, int uid,
             boolean ignoreBlocked) {
@@ -1320,12 +1336,7 @@
         if (ignoreBlocked) {
             return false;
         }
-        synchronized (mVpns) {
-            final Vpn vpn = mVpns.get(UserHandle.getUserId(uid));
-            if (vpn != null && vpn.getLockdown() && vpn.isBlockingUid(uid)) {
-                return true;
-            }
-        }
+        if (isUidBlockedByVpn(uid, mVpnBlockedUidRanges)) return true;
         final String iface = (lp == null ? "" : lp.getInterfaceName());
         return mPolicyManagerInternal.isUidNetworkingBlocked(uid, iface);
     }
@@ -1992,29 +2003,18 @@
     void handleRestrictBackgroundChanged(boolean restrictBackground) {
         if (mRestrictBackground == restrictBackground) return;
 
+        final List<UidRange> blockedRanges = mVpnBlockedUidRanges;
         for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
             final boolean curMetered = nai.networkCapabilities.isMetered();
             maybeNotifyNetworkBlocked(nai, curMetered, curMetered, mRestrictBackground,
-                    restrictBackground);
+                    restrictBackground, blockedRanges, blockedRanges);
         }
 
         mRestrictBackground = restrictBackground;
     }
 
-    private boolean isUidNetworkingWithVpnBlocked(int uid, int uidRules, boolean isNetworkMetered,
+    private boolean isUidBlockedByRules(int uid, int uidRules, boolean isNetworkMetered,
             boolean isBackgroundRestricted) {
-        synchronized (mVpns) {
-            final Vpn vpn = mVpns.get(UserHandle.getUserId(uid));
-            // Because the return value of this function depends on the list of UIDs the
-            // always-on VPN blocks when in lockdown mode, when the always-on VPN changes that
-            // list all state depending on the return value of this function has to be recomputed.
-            // TODO: add a trigger when the always-on VPN sets its blocked UIDs to reevaluate and
-            // send the necessary onBlockedStatusChanged callbacks.
-            if (vpn != null && vpn.getLockdown() && vpn.isBlockingUid(uid)) {
-                return true;
-            }
-        }
-
         return NetworkPolicyManagerInternal.isUidNetworkingBlocked(uid, uidRules,
                 isNetworkMetered, isBackgroundRestricted);
     }
@@ -4289,6 +4289,9 @@
                 case EVENT_DATA_SAVER_CHANGED:
                     handleRestrictBackgroundChanged(toBool(msg.arg1));
                     break;
+                case EVENT_SET_REQUIRE_VPN_FOR_UIDS:
+                    handleSetRequireVpnForUids(toBool(msg.arg1), (UidRange[]) msg.obj);
+                    break;
             }
         }
     }
@@ -4457,8 +4460,7 @@
         if (!nai.everConnected) {
             return;
         }
-        LinkProperties lp = getLinkProperties(nai);
-        if (isNetworkWithLinkPropertiesBlocked(lp, uid, false)) {
+        if (isNetworkWithLinkPropertiesBlocked(nai.linkProperties, uid, false)) {
             return;
         }
         nai.networkMonitor().forceReevaluation(uid);
@@ -4885,6 +4887,56 @@
         }
     }
 
+    private boolean isUidBlockedByVpn(int uid, List<UidRange> blockedUidRanges) {
+        // Determine whether this UID is blocked because of always-on VPN lockdown. If a VPN applies
+        // to the UID, then the UID is not blocked because always-on VPN lockdown applies only when
+        // a VPN is not up.
+        final NetworkAgentInfo vpnNai = getVpnForUid(uid);
+        if (vpnNai != null && !vpnNai.networkAgentConfig.allowBypass) return false;
+        for (UidRange range : blockedUidRanges) {
+            if (range.contains(uid)) return true;
+        }
+        return false;
+    }
+
+    @Override
+    public void setRequireVpnForUids(boolean requireVpn, UidRange[] ranges) {
+        NetworkStack.checkNetworkStackPermission(mContext);
+        mHandler.sendMessage(mHandler.obtainMessage(EVENT_SET_REQUIRE_VPN_FOR_UIDS,
+                encodeBool(requireVpn), 0 /* arg2 */, ranges));
+    }
+
+    private void handleSetRequireVpnForUids(boolean requireVpn, UidRange[] ranges) {
+        if (DBG) {
+            Log.d(TAG, "Setting VPN " + (requireVpn ? "" : "not ") + "required for UIDs: "
+                    + Arrays.toString(ranges));
+        }
+        // Cannot use a Set since the list of UID ranges might contain duplicates.
+        final List<UidRange> newVpnBlockedUidRanges = new ArrayList(mVpnBlockedUidRanges);
+        for (int i = 0; i < ranges.length; i++) {
+            if (requireVpn) {
+                newVpnBlockedUidRanges.add(ranges[i]);
+            } else {
+                newVpnBlockedUidRanges.remove(ranges[i]);
+            }
+        }
+
+        try {
+            mNetd.networkRejectNonSecureVpn(requireVpn, toUidRangeStableParcels(ranges));
+        } catch (RemoteException | ServiceSpecificException e) {
+            Log.e(TAG, "setRequireVpnForUids(" + requireVpn + ", "
+                    + Arrays.toString(ranges) + "): netd command failed: " + e);
+        }
+
+        for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
+            final boolean curMetered = nai.networkCapabilities.isMetered();
+            maybeNotifyNetworkBlocked(nai, curMetered, curMetered, mRestrictBackground,
+                    mRestrictBackground, mVpnBlockedUidRanges, newVpnBlockedUidRanges);
+        }
+
+        mVpnBlockedUidRanges = newVpnBlockedUidRanges;
+    }
+
     @Override
     public boolean updateLockdownVpn() {
         if (mDeps.getCallingUid() != Process.SYSTEM_UID) {
@@ -5870,6 +5922,12 @@
     // NOTE: Only should be accessed on ConnectivityServiceThread, except dump().
     private final ArraySet<NetworkAgentInfo> mNetworkAgentInfos = new ArraySet<>();
 
+    // UID ranges for users that are currently blocked by VPNs.
+    // This array is accessed and iterated on multiple threads without holding locks, so its
+    // contents must never be mutated. When the ranges change, the array is replaced with a new one
+    // (on the handler thread).
+    private volatile List<UidRange> mVpnBlockedUidRanges = new ArrayList<>();
+
     @GuardedBy("mBlockedAppUids")
     private final HashSet<Integer> mBlockedAppUids = new HashSet<>();
 
@@ -6524,7 +6582,7 @@
 
             if (meteredChanged) {
                 maybeNotifyNetworkBlocked(nai, oldMetered, newMetered, mRestrictBackground,
-                        mRestrictBackground);
+                        mRestrictBackground, mVpnBlockedUidRanges, mVpnBlockedUidRanges);
             }
 
             final boolean roamingChanged = prevNc.hasCapability(NET_CAPABILITY_NOT_ROAMING) !=
@@ -6589,6 +6647,15 @@
         return stableRanges;
     }
 
+    private static UidRangeParcel[] toUidRangeStableParcels(UidRange[] ranges) {
+        final UidRangeParcel[] stableRanges = new UidRangeParcel[ranges.length];
+        for (int i = 0; i < ranges.length; i++) {
+            stableRanges[i] = new UidRangeParcel(ranges[i].start, ranges[i].stop);
+        }
+        return stableRanges;
+    }
+
+
     private void updateUids(NetworkAgentInfo nai, NetworkCapabilities prevNc,
             NetworkCapabilities newNc) {
         Set<UidRange> prevRanges = null == prevNc ? null : prevNc.getUids();
@@ -7416,7 +7483,9 @@
         }
 
         final boolean metered = nai.networkCapabilities.isMetered();
-        final boolean blocked = isUidNetworkingWithVpnBlocked(nri.mUid, mUidRules.get(nri.mUid),
+        boolean blocked;
+        blocked = isUidBlockedByVpn(nri.mUid, mVpnBlockedUidRanges);
+        blocked |= isUidBlockedByRules(nri.mUid, mUidRules.get(nri.mUid),
                 metered, mRestrictBackground);
         callCallbackForRequest(nri, nai, ConnectivityManager.CALLBACK_AVAILABLE, blocked ? 1 : 0);
     }
@@ -7438,21 +7507,25 @@
      * @param newRestrictBackground True if data saver is enabled.
      */
     private void maybeNotifyNetworkBlocked(NetworkAgentInfo nai, boolean oldMetered,
-            boolean newMetered, boolean oldRestrictBackground, boolean newRestrictBackground) {
+            boolean newMetered, boolean oldRestrictBackground, boolean newRestrictBackground,
+            List<UidRange> oldBlockedUidRanges, List<UidRange> newBlockedUidRanges) {
 
         for (int i = 0; i < nai.numNetworkRequests(); i++) {
             NetworkRequest nr = nai.requestAt(i);
             NetworkRequestInfo nri = mNetworkRequests.get(nr);
             final int uidRules = mUidRules.get(nri.mUid);
-            final boolean oldBlocked, newBlocked;
-            // mVpns lock needs to be hold here to ensure that the active VPN cannot be changed
-            // between these two calls.
-            synchronized (mVpns) {
-                oldBlocked = isUidNetworkingWithVpnBlocked(nri.mUid, uidRules, oldMetered,
-                        oldRestrictBackground);
-                newBlocked = isUidNetworkingWithVpnBlocked(nri.mUid, uidRules, newMetered,
-                        newRestrictBackground);
-            }
+            final boolean oldBlocked, newBlocked, oldVpnBlocked, newVpnBlocked;
+
+            oldVpnBlocked = isUidBlockedByVpn(nri.mUid, oldBlockedUidRanges);
+            newVpnBlocked = (oldBlockedUidRanges != newBlockedUidRanges)
+                    ? isUidBlockedByVpn(nri.mUid, newBlockedUidRanges)
+                    : oldVpnBlocked;
+
+            oldBlocked = oldVpnBlocked || isUidBlockedByRules(nri.mUid, uidRules, oldMetered,
+                    oldRestrictBackground);
+            newBlocked = newVpnBlocked || isUidBlockedByRules(nri.mUid, uidRules, newMetered,
+                    newRestrictBackground);
+
             if (oldBlocked != newBlocked) {
                 callCallbackForRequest(nri, nai, ConnectivityManager.CALLBACK_BLK_CHANGED,
                         encodeBool(newBlocked));
@@ -7468,17 +7541,12 @@
     private void maybeNotifyNetworkBlockedForNewUidRules(int uid, int newRules) {
         for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
             final boolean metered = nai.networkCapabilities.isMetered();
+            final boolean vpnBlocked = isUidBlockedByVpn(uid, mVpnBlockedUidRanges);
             final boolean oldBlocked, newBlocked;
-            // TODO: Consider that doze mode or turn on/off battery saver would deliver lots of uid
-            // rules changed event. And this function actually loop through all connected nai and
-            // its requests. It seems that mVpns lock will be grabbed frequently in this case.
-            // Reduce the number of locking or optimize the use of lock are likely needed in future.
-            synchronized (mVpns) {
-                oldBlocked = isUidNetworkingWithVpnBlocked(
-                        uid, mUidRules.get(uid), metered, mRestrictBackground);
-                newBlocked = isUidNetworkingWithVpnBlocked(
-                        uid, newRules, metered, mRestrictBackground);
-            }
+            oldBlocked = vpnBlocked || isUidBlockedByRules(
+                    uid, mUidRules.get(uid), metered, mRestrictBackground);
+            newBlocked = vpnBlocked || isUidBlockedByRules(
+                    uid, newRules, metered, mRestrictBackground);
             if (oldBlocked == newBlocked) {
                 continue;
             }
diff --git a/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java b/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
index 841c970..7bde4d5 100644
--- a/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
+++ b/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
@@ -146,7 +146,11 @@
     // Underlying networks declared by the agent. Only set if supportsUnderlyingNetworks is true.
     // The networks in this list might be declared by a VPN app using setUnderlyingNetworks and are
     // not guaranteed to be current or correct, or even to exist.
-    public @Nullable Network[] declaredUnderlyingNetworks;
+    //
+    // This array is read and iterated on multiple threads with no locking so its contents must
+    // never be modified. When the list of networks changes, replace with a new array, on the
+    // handler thread.
+    public @Nullable volatile Network[] declaredUnderlyingNetworks;
 
     // The capabilities originally announced by the NetworkAgent, regardless of any capabilities
     // that were added or removed due to this network's underlying networks.
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index efd852b..9421acd 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -299,6 +299,7 @@
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 import kotlin.reflect.KClass;
 
@@ -412,6 +413,7 @@
 
         @Spy private Resources mResources;
         private final LinkedBlockingQueue<Intent> mStartedActivities = new LinkedBlockingQueue<>();
+
         // Map of permission name -> PermissionManager.Permission_{GRANTED|DENIED} constant
         private final HashMap<String, Integer> mMockedPermissions = new HashMap<>();
 
@@ -6509,6 +6511,26 @@
         checkNetworkInfo(mCm.getNetworkInfo(type), type, state);
     }
 
+    // Checks that each of the |agents| receive a blocked status change callback with the specified
+    // |blocked| value, in any order. This is needed because when an event affects multiple
+    // networks, ConnectivityService does not guarantee the order in which callbacks are fired.
+    private void assertBlockedCallbackInAnyOrder(TestNetworkCallback callback, boolean blocked,
+            TestNetworkAgentWrapper... agents) {
+        final List<Network> expectedNetworks = Arrays.asList(agents).stream()
+                .map((agent) -> agent.getNetwork())
+                .collect(Collectors.toList());
+
+        // Expect exactly one blocked callback for each agent.
+        for (int i = 0; i < agents.length; i++) {
+            CallbackEntry e = callback.expectCallbackThat(TIMEOUT_MS, (c) ->
+                    c instanceof CallbackEntry.BlockedStatus
+                            && ((CallbackEntry.BlockedStatus) c).getBlocked() == blocked);
+            Network network = e.getNetwork();
+            assertTrue("Received unexpected blocked callback for network " + network,
+                    expectedNetworks.remove(network));
+        }
+    }
+
     @Test
     public void testNetworkBlockedStatusAlwaysOnVpn() throws Exception {
         mServiceContext.setPermission(
@@ -6555,9 +6577,10 @@
         assertNetworkInfo(TYPE_WIFI, DetailedState.BLOCKED);
 
         // Disable lockdown, expect to see the network unblocked.
-        // There are no callbacks because they are not implemented yet.
         mService.setAlwaysOnVpnPackage(userId, null, false /* lockdown */, allowList);
         expectNetworkRejectNonSecureVpn(inOrder, false, firstHalf, secondHalf);
+        callback.expectBlockedStatusCallback(false, mWiFiNetworkAgent);
+        defaultCallback.expectBlockedStatusCallback(false, mWiFiNetworkAgent);
         vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
@@ -6605,6 +6628,8 @@
         allowList.clear();
         mService.setAlwaysOnVpnPackage(userId, ALWAYS_ON_PACKAGE, true /* lockdown */, allowList);
         expectNetworkRejectNonSecureVpn(inOrder, true, firstHalf, secondHalf);
+        defaultCallback.expectBlockedStatusCallback(true, mWiFiNetworkAgent);
+        assertBlockedCallbackInAnyOrder(callback, true, mWiFiNetworkAgent, mCellNetworkAgent);
         vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertNull(mCm.getActiveNetwork());
@@ -6614,6 +6639,8 @@
 
         // Disable lockdown. Everything is unblocked.
         mService.setAlwaysOnVpnPackage(userId, null, false /* lockdown */, allowList);
+        defaultCallback.expectBlockedStatusCallback(false, mWiFiNetworkAgent);
+        assertBlockedCallbackInAnyOrder(callback, false, mWiFiNetworkAgent, mCellNetworkAgent);
         vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
@@ -6647,6 +6674,8 @@
 
         // Enable lockdown and connect a VPN. The VPN is not blocked.
         mService.setAlwaysOnVpnPackage(userId, ALWAYS_ON_PACKAGE, true /* lockdown */, allowList);
+        defaultCallback.expectBlockedStatusCallback(true, mWiFiNetworkAgent);
+        assertBlockedCallbackInAnyOrder(callback, true, mWiFiNetworkAgent, mCellNetworkAgent);
         vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertNull(mCm.getActiveNetwork());
@@ -6658,7 +6687,7 @@
         defaultCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
         vpnUidCallback.assertNoCallback();  // vpnUidCallback has NOT_VPN capability.
         assertEquals(mMockVpn.getNetwork(), mCm.getActiveNetwork());
-        assertEquals(null, mCm.getActiveNetworkForUid(VPN_UID));  // BUG?
+        assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
         assertNetworkInfo(TYPE_MOBILE, DetailedState.DISCONNECTED);
         assertNetworkInfo(TYPE_VPN, DetailedState.CONNECTED);
diff --git a/tests/net/java/com/android/server/connectivity/VpnTest.java b/tests/net/java/com/android/server/connectivity/VpnTest.java
index cc47317..3648c4d 100644
--- a/tests/net/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/net/java/com/android/server/connectivity/VpnTest.java
@@ -27,7 +27,6 @@
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.AdditionalMatchers.aryEq;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -89,6 +88,7 @@
 import android.security.KeyStore;
 import android.util.ArrayMap;
 import android.util.ArraySet;
+import android.util.Range;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -350,7 +350,7 @@
 
         // Set always-on with lockdown.
         assertTrue(vpn.setAlwaysOnPackage(PKGS[1], true, null, mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[1] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[1] + 1, user.stop)
         }));
@@ -361,12 +361,11 @@
 
         // Switch to another app.
         assertTrue(vpn.setAlwaysOnPackage(PKGS[3], true, null, mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[1] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[1] + 1, user.stop)
         }));
-
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[3] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[3] + 1, user.stop)
         }));
@@ -383,7 +382,7 @@
         // Set always-on with lockdown and allow app PKGS[2] from lockdown.
         assertTrue(vpn.setAlwaysOnPackage(
                 PKGS[1], true, Collections.singletonList(PKGS[2]), mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[1] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[2] + 1, user.stop)
         }));
@@ -392,10 +391,10 @@
         // Change allowed app list to PKGS[3].
         assertTrue(vpn.setAlwaysOnPackage(
                 PKGS[1], true, Collections.singletonList(PKGS[3]), mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[2] + 1, user.stop)
         }));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[1] + 1, user.start + PKG_UIDS[3] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[3] + 1, user.stop)
         }));
@@ -405,11 +404,11 @@
         // Change the VPN app.
         assertTrue(vpn.setAlwaysOnPackage(
                 PKGS[0], true, Collections.singletonList(PKGS[3]), mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[1] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[1] + 1, user.start + PKG_UIDS[3] - 1)
         }));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[0] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.start + PKG_UIDS[3] - 1)
         }));
@@ -418,11 +417,11 @@
 
         // Remove the list of allowed packages.
         assertTrue(vpn.setAlwaysOnPackage(PKGS[0], true, null, mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.start + PKG_UIDS[3] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[3] + 1, user.stop)
         }));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.stop),
         }));
         assertBlocked(vpn, user.start + PKG_UIDS[1], user.start + PKG_UIDS[2],
@@ -432,10 +431,10 @@
         // Add the list of allowed packages.
         assertTrue(vpn.setAlwaysOnPackage(
                 PKGS[0], true, Collections.singletonList(PKGS[1]), mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.stop)
         }));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.start + PKG_UIDS[1] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[1] + 1, user.stop)
         }));
@@ -450,11 +449,11 @@
         // allowed package should change from PGKS[1] to PKGS[2].
         assertTrue(vpn.setAlwaysOnPackage(
                 PKGS[0], true, Arrays.asList("com.foo.app", PKGS[2], "com.bar.app"), mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[]{
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.start + PKG_UIDS[1] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[1] + 1, user.stop)
         }));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[]{
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start + PKG_UIDS[0] + 1, user.start + PKG_UIDS[2] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[2] + 1, user.stop)
         }));
@@ -475,7 +474,7 @@
 
         // Set lockdown.
         assertTrue(vpn.setAlwaysOnPackage(PKGS[3], true, null, mKeyStore));
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(user.start, user.start + PKG_UIDS[3] - 1),
                 new UidRangeParcel(user.start + PKG_UIDS[3] + 1, user.stop)
         }));
@@ -485,7 +484,7 @@
         // Add the restricted user.
         setMockedUsers(primaryUser, tempProfile);
         vpn.onUserAdded(tempProfile.id);
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(profile.start, profile.start + PKG_UIDS[3] - 1),
                 new UidRangeParcel(profile.start + PKG_UIDS[3] + 1, profile.stop)
         }));
@@ -493,7 +492,7 @@
         // Remove the restricted user.
         tempProfile.partial = true;
         vpn.onUserRemoved(tempProfile.id);
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(new UidRangeParcel[] {
+        verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(new UidRangeParcel[] {
                 new UidRangeParcel(profile.start, profile.start + PKG_UIDS[3] - 1),
                 new UidRangeParcel(profile.start + PKG_UIDS[3] + 1, profile.stop)
         }));
@@ -506,22 +505,29 @@
                 new UidRangeParcel(PRI_USER_RANGE.start, PRI_USER_RANGE.stop)};
         // Given legacy lockdown is already enabled,
         vpn.setLockdown(true);
-
-        verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(primaryUserRangeParcel));
+        verify(mConnectivityManager, times(1)).setRequireVpnForUids(true,
+                toRanges(primaryUserRangeParcel));
 
         // Enabling legacy lockdown twice should do nothing.
         vpn.setLockdown(true);
-        verify(mNetd, times(1))
-                .networkRejectNonSecureVpn(anyBoolean(), any(UidRangeParcel[].class));
+        verify(mConnectivityManager, times(1)).setRequireVpnForUids(anyBoolean(), any());
 
         // And disabling should remove the rules exactly once.
         vpn.setLockdown(false);
-        verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(primaryUserRangeParcel));
+        verify(mConnectivityManager, times(1)).setRequireVpnForUids(false,
+                toRanges(primaryUserRangeParcel));
 
         // Removing the lockdown again should have no effect.
         vpn.setLockdown(false);
-        verify(mNetd, times(2)).networkRejectNonSecureVpn(
-                anyBoolean(), any(UidRangeParcel[].class));
+        verify(mConnectivityManager, times(2)).setRequireVpnForUids(anyBoolean(), any());
+    }
+
+    private ArrayList<Range<Integer>> toRanges(UidRangeParcel[] ranges) {
+        ArrayList<Range<Integer>> rangesArray = new ArrayList<>(ranges.length);
+        for (int i = 0; i < ranges.length; i++) {
+            rangesArray.add(new Range<>(ranges[i].start, ranges[i].stop));
+        }
+        return rangesArray;
     }
 
     @Test
@@ -535,21 +541,21 @@
             new UidRangeParcel(entireUser[0].start + PKG_UIDS[0] + 1, entireUser[0].stop)
         };
 
-        final InOrder order = inOrder(mNetd);
+        final InOrder order = inOrder(mConnectivityManager);
 
         // Given lockdown is enabled with no package (legacy VPN),
         vpn.setLockdown(true);
-        order.verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(entireUser));
+        order.verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(entireUser));
 
         // When a new VPN package is set the rules should change to cover that package.
         vpn.prepare(null, PKGS[0], VpnManager.TYPE_VPN_SERVICE);
-        order.verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(entireUser));
-        order.verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(exceptPkg0));
+        order.verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(entireUser));
+        order.verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(exceptPkg0));
 
         // When that VPN package is unset, everything should be undone again in reverse.
         vpn.prepare(null, VpnConfig.LEGACY_VPN, VpnManager.TYPE_VPN_SERVICE);
-        order.verify(mNetd).networkRejectNonSecureVpn(eq(false), aryEq(exceptPkg0));
-        order.verify(mNetd).networkRejectNonSecureVpn(eq(true), aryEq(entireUser));
+        order.verify(mConnectivityManager).setRequireVpnForUids(false, toRanges(exceptPkg0));
+        order.verify(mConnectivityManager).setRequireVpnForUids(true, toRanges(entireUser));
     }
 
     @Test