Merge "Check location permission for ConnDiags last." into sc-dev
diff --git a/framework/src/android/net/ConnectivityDiagnosticsManager.java b/framework/src/android/net/ConnectivityDiagnosticsManager.java
index 3598ebc..dcc8a5e 100644
--- a/framework/src/android/net/ConnectivityDiagnosticsManager.java
+++ b/framework/src/android/net/ConnectivityDiagnosticsManager.java
@@ -713,7 +713,9 @@
      * <p>Callbacks registered by apps not meeting the above criteria will not be invoked.
      *
      * <p>If a registering app loses its relevant permissions, any callbacks it registered will
-     * silently stop receiving callbacks.
+     * silently stop receiving callbacks. Note that registering apps must also hold location
+     * permissions to receive callbacks, because some Networks (such as Wi-Fi networks) may be
+     * location-bound.
      *
      * <p>Each register() call <b>MUST</b> use a ConnectivityDiagnosticsCallback instance that is
      * not currently registered. If a ConnectivityDiagnosticsCallback instance is registered with
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 29a4856..5c47f27 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -9162,6 +9162,34 @@
         return results;
     }
 
+    /**
+     * Returns whether the given app passes the location permission check of
+     * {@code LocationPermissionChecker#checkLocationPermission} (no feature id, no message).
+     */
+    private boolean hasLocationPermission(String packageName, int uid) {
+        // checkLocationPermission can throw SecurityException if the uid and package name don't
+        // match. Throwing on the CS thread is not acceptable, so such a failure is treated as
+        // the app lacking location permission.
+        try {
+            return mLocationPermissionChecker.checkLocationPermission(
+                    packageName, null /* featureId */, uid, null /* message */);
+        } catch (SecurityException e) {
+            return false;
+        }
+    }
+
+    private boolean ownsVpnRunningOverNetwork(int uid, Network network) {
+        // True if uid owns an agent that supports underlying networks and declares network.
+        for (NetworkAgentInfo nai : mNetworkAgentInfos) {
+            final boolean isOwnedVpn = nai.supportsUnderlyingNetworks()
+                    && nai.networkCapabilities.getOwnerUid() == uid;
+            if (isOwnedVpn && CollectionUtils.contains(nai.declaredUnderlyingNetworks, network)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
     @VisibleForTesting
     boolean checkConnectivityDiagnosticsPermissions(
             int callbackPid, int callbackUid, NetworkAgentInfo nai, String callbackPackageName) {
@@ -9169,29 +9197,14 @@
             return true;
         }
 
-        // LocationPermissionChecker#checkLocationPermission can throw SecurityException if the uid
-        // and package name don't match. Throwing on the CS thread is not acceptable, so wrap the
-        // call in a try-catch.
-        try {
-            if (!mLocationPermissionChecker.checkLocationPermission(
-                    callbackPackageName, null /* featureId */, callbackUid, null /* message */)) {
-                return false;
-            }
-        } catch (SecurityException e) {
+        // Administrator UIDs also contain the Owner UID.
+        final int[] administratorUids = nai.networkCapabilities.getAdministratorUids();
+        if (!CollectionUtils.contains(administratorUids, callbackUid)
+                && !ownsVpnRunningOverNetwork(callbackUid, nai.network)) {
             return false;
         }
 
-        for (NetworkAgentInfo virtual : mNetworkAgentInfos) {
-            if (virtual.supportsUnderlyingNetworks()
-                    && virtual.networkCapabilities.getOwnerUid() == callbackUid
-                    && CollectionUtils.contains(virtual.declaredUnderlyingNetworks, nai.network)) {
-                return true;
-            }
-        }
-
-        // Administrator UIDs also contains the Owner UID
-        final int[] administratorUids = nai.networkCapabilities.getAdministratorUids();
-        return CollectionUtils.contains(administratorUids, callbackUid);
+        return hasLocationPermission(callbackPackageName, callbackUid);
     }
 
     @Override
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 29a411e..4661385 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -9940,28 +9940,32 @@
 
     @Test
     public void testCheckConnectivityDiagnosticsPermissionsWrongUidPackageName() throws Exception {
-        final NetworkAgentInfo naiWithoutUid = fakeMobileNai(new NetworkCapabilities());
+        final int wrongUid = Process.myUid() + 1;
+        // Make wrongUid an admin so the check reaches the location permission step.
+        final NetworkCapabilities nc = new NetworkCapabilities();
+        nc.setAdministratorUids(new int[] {wrongUid});
+        final NetworkAgentInfo naiWithUid = fakeMobileNai(nc);
 
         mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
 
         assertFalse(
                 "Mismatched uid/package name should not pass the location permission check",
                 mService.checkConnectivityDiagnosticsPermissions(
-                        Process.myPid() + 1, Process.myUid() + 1, naiWithoutUid,
-                        mContext.getOpPackageName()));
+                        Process.myPid() + 1, wrongUid, naiWithUid, mContext.getOpPackageName()));
     }
 
     @Test
     public void testCheckConnectivityDiagnosticsPermissionsNoLocationPermission() throws Exception {
-        final NetworkAgentInfo naiWithoutUid = fakeMobileNai(new NetworkCapabilities());
+        final NetworkCapabilities nc = new NetworkCapabilities();
+        nc.setAdministratorUids(new int[] {Process.myUid()}); // Reach the location check.
+        final NetworkAgentInfo naiWithUid = fakeMobileNai(nc);
 
         mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
 
         assertFalse(
                 "ACCESS_FINE_LOCATION permission necessary for Connectivity Diagnostics",
                 mService.checkConnectivityDiagnosticsPermissions(
-                        Process.myPid(), Process.myUid(), naiWithoutUid,
-                        mContext.getOpPackageName()));
+                        Process.myPid(), Process.myUid(), naiWithUid, mContext.getOpPackageName()));
     }
 
     @Test