ConnectivityManager: a simpler CallbackHandler

This patch simplifies CallbackHandler in the following way:
  - CallbackHandler directly uses the static references to
    sNetworkCallback and sCallbackRefCount. This allows to remove
    instance fields in CallbackHandler.
  - CallbackHandler does not have a reference to ConnectivityManager
    anymore
  - CallbackHandler.getObject() is now generic in a type-safe way.

Test: ConnectivityServiceTest passes
Bug: 28537383
Bug: 32130437
Change-Id: I5004da5b91498e6ff7f8b05057a9e24b975bb56e
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index 9e5aaf5..db34124 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -2709,24 +2709,17 @@
     }
 
     private class CallbackHandler extends Handler {
-        private final HashMap<NetworkRequest, NetworkCallback>mCallbackMap;
-        private final AtomicInteger mRefCount;
         private static final String TAG = "ConnectivityManager.CallbackHandler";
-        private final ConnectivityManager mCm;
         private static final boolean DBG = false;
 
-        CallbackHandler(Looper looper, HashMap<NetworkRequest, NetworkCallback>callbackMap,
-                AtomicInteger refCount, ConnectivityManager cm) {
+        CallbackHandler(Looper looper) {
             super(looper);
-            mCallbackMap = callbackMap;
-            mRefCount = refCount;
-            mCm = cm;
         }
 
         @Override
         public void handleMessage(Message message) {
-            NetworkRequest request = (NetworkRequest) getObject(message, NetworkRequest.class);
-            Network network = (Network) getObject(message, Network.class);
+            NetworkRequest request = getObject(message, NetworkRequest.class);
+            Network network = getObject(message, Network.class);
             if (DBG) {
                 Log.d(TAG, whatToString(message.what) + " for network " + network);
             }
@@ -2769,9 +2762,7 @@
                 case CALLBACK_CAP_CHANGED: {
                     NetworkCallback callback = getCallback(request, "CAP_CHANGED");
                     if (callback != null) {
-                        NetworkCapabilities cap = (NetworkCapabilities)getObject(message,
-                                NetworkCapabilities.class);
-
+                        NetworkCapabilities cap = getObject(message, NetworkCapabilities.class);
                         callback.onCapabilitiesChanged(network, cap);
                     }
                     break;
@@ -2779,9 +2770,7 @@
                 case CALLBACK_IP_CHANGED: {
                     NetworkCallback callback = getCallback(request, "IP_CHANGED");
                     if (callback != null) {
-                        LinkProperties lp = (LinkProperties)getObject(message,
-                                LinkProperties.class);
-
+                        LinkProperties lp = getObject(message, LinkProperties.class);
                         callback.onLinkPropertiesChanged(network, lp);
                     }
                     break;
@@ -2802,12 +2791,12 @@
                 }
                 case CALLBACK_RELEASED: {
                     NetworkCallback callback = null;
-                    synchronized(mCallbackMap) {
-                        callback = mCallbackMap.remove(request);
+                    synchronized(sCallbacks) {
+                        callback = sCallbacks.remove(request);
                     }
                     if (callback != null) {
-                        synchronized(mRefCount) {
-                            if (mRefCount.decrementAndGet() == 0) {
+                        synchronized(sCallbackRefCount) {
+                            if (sCallbackRefCount.decrementAndGet() == 0) {
                                 getLooper().quit();
                             }
                         }
@@ -2828,14 +2817,14 @@
             }
         }
 
-        private Object getObject(Message msg, Class c) {
-            return msg.getData().getParcelable(c.getSimpleName());
+        private <T> T getObject(Message msg, Class<T> c) {
+            return (T) msg.getData().getParcelable(c.getSimpleName());
         }
 
         private NetworkCallback getCallback(NetworkRequest req, String name) {
             NetworkCallback callback;
-            synchronized(mCallbackMap) {
-                callback = mCallbackMap.get(req);
+            synchronized(sCallbacks) {
+                callback = sCallbacks.get(req);
             }
             if (callback == null) {
                 Log.e(TAG, "callback not found for " + name + " message");
@@ -2850,8 +2839,7 @@
                 // TODO: switch this to ConnectivityThread
                 HandlerThread callbackThread = new HandlerThread("ConnectivityManager");
                 callbackThread.start();
-                sCallbackHandler = new CallbackHandler(callbackThread.getLooper(),
-                        sNetworkCallback, sCallbackRefCount, this);
+                sCallbackHandler = new CallbackHandler(callbackThread.getLooper());
             }
         }
     }
@@ -2865,8 +2853,7 @@
         }
     }
 
-    static final HashMap<NetworkRequest, NetworkCallback> sNetworkCallback =
-            new HashMap<NetworkRequest, NetworkCallback>();
+    static final HashMap<NetworkRequest, NetworkCallback> sCallbacks = new HashMap<>();
     static final AtomicInteger sCallbackRefCount = new AtomicInteger(0);
     static CallbackHandler sCallbackHandler = null;
 
@@ -2882,25 +2869,32 @@
         if (need == null && action != REQUEST) {
             throw new IllegalArgumentException("null NetworkCapabilities");
         }
+        // TODO: throw an exception if networkCallback.networkRequest is not null.
+        // http://b/20701525
+        final NetworkRequest request;
         try {
             incCallbackHandlerRefCount();
-            synchronized(sNetworkCallback) {
+            synchronized(sCallbacks) {
+                Messenger messenger = new Messenger(sCallbackHandler);
+                Binder binder = new Binder();
                 if (action == LISTEN) {
-                    networkCallback.networkRequest = mService.listenForNetwork(need,
-                            new Messenger(sCallbackHandler), new Binder());
+                    request = mService.listenForNetwork(need, messenger, binder);
                 } else {
-                    networkCallback.networkRequest = mService.requestNetwork(need,
-                            new Messenger(sCallbackHandler), timeoutMs, new Binder(), legacyType);
+                    request = mService.requestNetwork(
+                            need, messenger, timeoutMs, binder, legacyType);
                 }
-                if (networkCallback.networkRequest != null) {
-                    sNetworkCallback.put(networkCallback.networkRequest, networkCallback);
+                if (request != null) {
+                    sCallbacks.put(request, networkCallback);
                 }
+                networkCallback.networkRequest = request;
             }
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
-        if (networkCallback.networkRequest == null) decCallbackHandlerRefCount();
-        return networkCallback.networkRequest;
+        if (request == null) {
+            decCallbackHandlerRefCount();
+        }
+        return request;
     }
 
     /**