Merge "Prevent NPEs when registering/unregistering ConnDiags CBs."
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index ea6156a..47bde72 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -9525,6 +9525,10 @@
             @NonNull IConnectivityDiagnosticsCallback callback,
             @NonNull NetworkRequest request,
             @NonNull String callingPackageName) {
+        Objects.requireNonNull(callback, "callback must not be null");
+        Objects.requireNonNull(request, "request must not be null");
+        Objects.requireNonNull(callingPackageName, "callingPackageName must not be null");
+
         if (request.legacyType != TYPE_NONE) {
             throw new IllegalArgumentException("ConnectivityManager.TYPE_* are deprecated."
                     + " Please use NetworkCapabilities instead.");
@@ -9573,6 +9577,9 @@
     @Override
     public void simulateDataStall(int detectionMethod, long timestampMillis,
             @NonNull Network network, @NonNull PersistableBundle extras) {
+        Objects.requireNonNull(network, "network must not be null");
+        Objects.requireNonNull(extras, "extras must not be null");
+
         enforceAnyPermissionOf(android.Manifest.permission.MANAGE_TEST_NETWORKS,
                 android.Manifest.permission.NETWORK_STACK);
         final NetworkCapabilities nc = getNetworkCapabilitiesInternal(network);
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 7034c92..c19ed61 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -204,6 +204,7 @@
 import android.location.LocationManager;
 import android.net.CaptivePortalData;
 import android.net.ConnectionInfo;
+import android.net.ConnectivityDiagnosticsManager.DataStallReport;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.ConnectivityManager.PacketKeepalive;
@@ -285,6 +286,7 @@
 import android.os.Parcel;
 import android.os.ParcelFileDescriptor;
 import android.os.Parcelable;
+import android.os.PersistableBundle;
 import android.os.Process;
 import android.os.RemoteException;
 import android.os.ServiceSpecificException;
@@ -10126,6 +10128,35 @@
         assertTrue(mService.mConnectivityDiagnosticsCallbacks.containsKey(mIBinder));
     }
 
+    @Test(expected = NullPointerException.class)
+    public void testRegisterConnectivityDiagnosticsCallbackNullCallback() {
+        // A null callback is expected to be rejected with a NullPointerException.
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        final String pkgName = mContext.getPackageName();
+        mService.registerConnectivityDiagnosticsCallback(null /* callback */, request, pkgName);
+    }
+
+    @Test(expected = NullPointerException.class)
+    public void testRegisterConnectivityDiagnosticsCallbackNullNetworkRequest() {
+        // A null NetworkRequest is expected to be rejected with a NullPointerException.
+        final String pkgName = mContext.getPackageName();
+        mService.registerConnectivityDiagnosticsCallback(
+                mConnectivityDiagnosticsCallback, null /* request */, pkgName);
+    }
+
+    @Test(expected = NullPointerException.class)
+    public void testRegisterConnectivityDiagnosticsCallbackNullPackageName() {
+        // A null calling package name is expected to be rejected with a NullPointerException.
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        mService.registerConnectivityDiagnosticsCallback(
+                mConnectivityDiagnosticsCallback, request, null /* callingPackageName */);
+    }
+
+    @Test(expected = NullPointerException.class)
+    public void testUnregisterConnectivityDiagnosticsCallbackNullCallback() {
+        mService.unregisterConnectivityDiagnosticsCallback(null /* callback */);
+    }
+
     public NetworkAgentInfo fakeMobileNai(NetworkCapabilities nc) {
         final NetworkCapabilities cellNc = new NetworkCapabilities.Builder(nc)
                 .addTransportType(TRANSPORT_CELLULAR).build();
@@ -10436,6 +10467,24 @@
                                 areConnDiagCapsRedacted(report.getNetworkCapabilities())));
     }
 
+    @Test(expected = NullPointerException.class)
+    public void testSimulateDataStallNullNetwork() {
+        // A null Network is expected to be rejected with a NullPointerException.
+        final int detectionMethod = DataStallReport.DETECTION_METHOD_DNS_EVENTS;
+        final PersistableBundle extras = new PersistableBundle();
+        mService.simulateDataStall(
+                detectionMethod, 0L /* timestampMillis */, null /* network */, extras);
+    }
+
+    @Test(expected = NullPointerException.class)
+    public void testSimulateDataStallNullPersistableBundle() {
+        // A null extras bundle is expected to be rejected with a NullPointerException.
+        final int detectionMethod = DataStallReport.DETECTION_METHOD_DNS_EVENTS;
+        final Network network = mock(Network.class);
+        mService.simulateDataStall(
+                detectionMethod, 0L /* timestampMillis */, network, null /* extras */);
+    }
+
     @Test
     public void testRouteAddDeleteUpdate() throws Exception {
         final NetworkRequest request = new NetworkRequest.Builder().build();