Merge "[DU09-0]Adding the NetworkStatsCollection Builder"
diff --git a/framework-t/src/android/app/usage/NetworkStatsManager.java b/framework-t/src/android/app/usage/NetworkStatsManager.java
index 8813f98..28f930f 100644
--- a/framework-t/src/android/app/usage/NetworkStatsManager.java
+++ b/framework-t/src/android/app/usage/NetworkStatsManager.java
@@ -21,6 +21,7 @@
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 
 import android.Manifest;
+import android.annotation.CallbackExecutor;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
@@ -39,14 +40,11 @@
 import android.net.NetworkStateSnapshot;
 import android.net.NetworkTemplate;
 import android.net.UnderlyingNetworkInfo;
+import android.net.netstats.IUsageCallback;
 import android.net.netstats.provider.INetworkStatsProviderCallback;
 import android.net.netstats.provider.NetworkStatsProvider;
-import android.os.Binder;
 import android.os.Build;
 import android.os.Handler;
-import android.os.Looper;
-import android.os.Message;
-import android.os.Messenger;
 import android.os.RemoteException;
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
@@ -57,6 +55,7 @@
 
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.Executor;
 
 /**
  * Provides access to network usage history and statistics. Usage data is collected in
@@ -723,26 +722,36 @@
         }
     }
 
-    /** @hide */
-    public void registerUsageCallback(NetworkTemplate template, int networkType,
-            long thresholdBytes, UsageCallback callback, @Nullable Handler handler) {
+    /**
+     * Registers to receive notifications about data usage on specified networks.
+     *
+     * <p>The callbacks will continue to be called as long as the process is alive, or
+     * until {@link #unregisterUsageCallback} is called.
+     *
+     * @param template Template used to match networks. See {@link NetworkTemplate}.
+     * @param thresholdBytes Threshold in bytes to be notified on. A provided value that is lower
+     *                       than 2MiB will be clamped for non-privileged callers.
+     * @param executor The executor on which callback will be invoked. The provided {@link Executor}
+     *                 must run callback sequentially, otherwise the order of callbacks cannot be
+     *                 guaranteed.
+     * @param callback The {@link UsageCallback} that the system will call when data usage
+     *                 has exceeded the specified threshold.
+     * @hide
+     */
+    @SystemApi(client = MODULE_LIBRARIES)
+    public void registerUsageCallback(@NonNull NetworkTemplate template, long thresholdBytes,
+            @NonNull @CallbackExecutor Executor executor, @NonNull UsageCallback callback) {
+        Objects.requireNonNull(template, "NetworkTemplate cannot be null");
         Objects.requireNonNull(callback, "UsageCallback cannot be null");
+        Objects.requireNonNull(executor, "Executor cannot be null");
 
-        final Looper looper;
-        if (handler == null) {
-            looper = Looper.myLooper();
-        } else {
-            looper = handler.getLooper();
-        }
-
-        DataUsageRequest request = new DataUsageRequest(DataUsageRequest.REQUEST_ID_UNSET,
+        final DataUsageRequest request = new DataUsageRequest(DataUsageRequest.REQUEST_ID_UNSET,
                 template, thresholdBytes);
         try {
-            CallbackHandler callbackHandler = new CallbackHandler(looper, networkType,
-                    template.getSubscriberId(), callback);
+            final UsageCallbackWrapper callbackWrapper =
+                    new UsageCallbackWrapper(executor, callback);
             callback.request = mService.registerUsageCallback(
-                    mContext.getOpPackageName(), request, new Messenger(callbackHandler),
-                    new Binder());
+                    mContext.getOpPackageName(), request, callbackWrapper);
             if (DBG) Log.d(TAG, "registerUsageCallback returned " + callback.request);
 
             if (callback.request == null) {
@@ -795,12 +804,15 @@
         NetworkTemplate template = createTemplate(networkType, subscriberId);
         if (DBG) {
             Log.d(TAG, "registerUsageCallback called with: {"
-                + " networkType=" + networkType
-                + " subscriberId=" + subscriberId
-                + " thresholdBytes=" + thresholdBytes
-                + " }");
+                    + " networkType=" + networkType
+                    + " subscriberId=" + subscriberId
+                    + " thresholdBytes=" + thresholdBytes
+                    + " }");
         }
-        registerUsageCallback(template, networkType, thresholdBytes, callback, handler);
+
+        final Executor executor = handler == null ? r -> r.run() : r -> handler.post(r);
+
+        registerUsageCallback(template, thresholdBytes, executor, callback);
     }
 
     /**
@@ -825,6 +837,26 @@
      * Base class for usage callbacks. Should be extended by applications wanting notifications.
      */
     public static abstract class UsageCallback {
+        /**
+         * Called when data usage has reached the given threshold.
+         *
+         * <p>Called by {@code NetworkStatsService} when the registered threshold is reached.
+         * If a caller overrides {@link #onThresholdReached(NetworkTemplate)}, the system
+         * will not call {@link #onThresholdReached(int, String)}.
+         *
+         * @param template The {@link NetworkTemplate} that is associated with this callback.
+         * @hide
+         */
+        @SystemApi(client = MODULE_LIBRARIES)
+        public void onThresholdReached(@NonNull NetworkTemplate template) {
+            // Backward compatibility: forward to the legacy callback when not overridden.
+            final int networkType = networkTypeForTemplate(template);
+            if (networkType != ConnectivityManager.TYPE_NONE) {
+                final String subscriberId = template.getSubscriberIds().isEmpty() ? null
+                        : template.getSubscriberIds().iterator().next();
+                onThresholdReached(networkType, subscriberId);
+            }
+        }
 
         /**
          * Called when data usage has reached the given threshold.
@@ -835,6 +867,25 @@
          * @hide used for internal bookkeeping
          */
         private DataUsageRequest request;
+
+        /**
+         * Gets the legacy network type that corresponds to a template, if there is one.
+         *
+         * @param template the target {@link NetworkTemplate}.
+         * @return the legacy network type; only the types already supported by
+         *         {@link #registerUsageCallback(int, String, long, UsageCallback, Handler)}
+         *         are mapped, and {@link ConnectivityManager#TYPE_NONE} is returned otherwise.
+         */
+        private static int networkTypeForTemplate(@NonNull NetworkTemplate template) {
+            switch (template.getMatchRule()) {
+                case NetworkTemplate.MATCH_MOBILE:
+                    return ConnectivityManager.TYPE_MOBILE;
+                case NetworkTemplate.MATCH_WIFI:
+                    return ConnectivityManager.TYPE_WIFI;
+                default:
+                    return ConnectivityManager.TYPE_NONE;
+            }
+        }
     }
 
     /**
@@ -953,43 +1004,32 @@
         }
     }
 
-    private static class CallbackHandler extends Handler {
-        private final int mNetworkType;
-        private final String mSubscriberId;
-        private UsageCallback mCallback;
+    private static class UsageCallbackWrapper extends IUsageCallback.Stub {
+        // Null if unregistered.
+        private volatile UsageCallback mCallback;
 
-        CallbackHandler(Looper looper, int networkType, String subscriberId,
-                UsageCallback callback) {
-            super(looper);
-            mNetworkType = networkType;
-            mSubscriberId = subscriberId;
+        private final Executor mExecutor;
+
+        UsageCallbackWrapper(@NonNull Executor executor, @NonNull UsageCallback callback) {
             mCallback = callback;
+            mExecutor = executor;
         }
 
         @Override
-        public void handleMessage(Message message) {
-            DataUsageRequest request =
-                    (DataUsageRequest) getObject(message, DataUsageRequest.PARCELABLE_KEY);
-
-            switch (message.what) {
-                case CALLBACK_LIMIT_REACHED: {
-                    if (mCallback != null) {
-                        mCallback.onThresholdReached(mNetworkType, mSubscriberId);
-                    } else {
-                        Log.e(TAG, "limit reached with released callback for " + request);
-                    }
-                    break;
-                }
-                case CALLBACK_RELEASED: {
-                    if (DBG) Log.d(TAG, "callback released for " + request);
-                    mCallback = null;
-                    break;
-                }
+        public void onThresholdReached(DataUsageRequest request) {
+            // Copy it to a local variable in case mCallback changed inside the if condition.
+            final UsageCallback callback = mCallback;
+            if (callback != null) {
+                mExecutor.execute(() -> callback.onThresholdReached(request.template));
+            } else {
+                Log.e(TAG, "onThresholdReached with released callback for " + request);
             }
         }
 
-        private static Object getObject(Message msg, String key) {
-            return msg.getData().getParcelable(key);
+        @Override
+        public void onCallbackReleased(DataUsageRequest request) {
+            if (DBG) Log.d(TAG, "callback released for " + request);
+            mCallback = null;
         }
     }
 
diff --git a/framework-t/src/android/net/INetworkStatsService.aidl b/framework-t/src/android/net/INetworkStatsService.aidl
index da0aa99..efe626d 100644
--- a/framework-t/src/android/net/INetworkStatsService.aidl
+++ b/framework-t/src/android/net/INetworkStatsService.aidl
@@ -24,6 +24,7 @@
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
 import android.net.UnderlyingNetworkInfo;
+import android.net.netstats.IUsageCallback;
 import android.net.netstats.provider.INetworkStatsProvider;
 import android.net.netstats.provider.INetworkStatsProviderCallback;
 import android.os.IBinder;
@@ -71,7 +72,7 @@
 
     /** Registers a callback on data usage. */
     DataUsageRequest registerUsageCallback(String callingPackage,
-            in DataUsageRequest request, in Messenger messenger, in IBinder binder);
+            in DataUsageRequest request, in IUsageCallback callback);
 
     /** Unregisters a callback on data usage. */
     void unregisterUsageRequest(in DataUsageRequest request);
diff --git a/framework-t/src/android/net/TrafficStats.java b/framework-t/src/android/net/TrafficStats.java
index c803a72..c2f0cdf 100644
--- a/framework-t/src/android/net/TrafficStats.java
+++ b/framework-t/src/android/net/TrafficStats.java
@@ -16,6 +16,8 @@
 
 package android.net;
 
+import static android.annotation.SystemApi.Client.MODULE_LIBRARIES;
+
 import android.annotation.NonNull;
 import android.annotation.SuppressLint;
 import android.annotation.SystemApi;
@@ -29,6 +31,7 @@
 import android.os.Binder;
 import android.os.Build;
 import android.os.RemoteException;
+import android.util.Log;
 
 import com.android.server.NetworkManagementSocketTagger;
 
@@ -210,10 +213,29 @@
         }
         final NetworkStatsManager statsManager =
                 context.getSystemService(NetworkStatsManager.class);
+        if (statsManager == null) {
+            // TODO: Currently Process.isSupplemental is not working yet, because it depends on
+            //  process to run in a certain UID range, which is not true for now. Change this
+            //  to Log.wtf once Process.isSupplemental is ready.
+            Log.e(TAG, "TrafficStats not initialized, uid=" + Binder.getCallingUid());
+            return;
+        }
         sStatsService = statsManager.getBinder();
     }
 
     /**
+     * Attach the socket tagger implementation to the current process, to
+     * get notified when a socket's {@link FileDescriptor} is assigned to
+     * a thread. See {@link SocketTagger#set(SocketTagger)}.
+     *
+     * @hide
+     */
+    @SystemApi(client = MODULE_LIBRARIES)
+    public static void attachSocketTagger() {
+        NetworkManagementSocketTagger.install();
+    }
+
+    /**
      * Set active tag to use when accounting {@link Socket} traffic originating
      * from the current thread. Only one active tag per thread is supported.
      * <p>
diff --git a/framework-t/src/android/net/netstats/IUsageCallback.aidl b/framework-t/src/android/net/netstats/IUsageCallback.aidl
new file mode 100644
index 0000000..4e8a5b2
--- /dev/null
+++ b/framework-t/src/android/net/netstats/IUsageCallback.aidl
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.netstats;
+
+import android.net.DataUsageRequest;
+
+/**
+ * Interface for NetworkStatsService to notify events to the callers of registerUsageCallback.
+ *
+ * @hide
+ */
+oneway interface IUsageCallback {
+    void onThresholdReached(in DataUsageRequest request);
+    void onCallbackReleased(in DataUsageRequest request);
+}
diff --git a/service-t/src/com/android/server/net/NetworkStatsFactory.java b/service-t/src/com/android/server/net/NetworkStatsFactory.java
index bb123a3..17f3455 100644
--- a/service-t/src/com/android/server/net/NetworkStatsFactory.java
+++ b/service-t/src/com/android/server/net/NetworkStatsFactory.java
@@ -26,10 +26,10 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.INetd;
+import android.content.Context;
+import android.net.ConnectivityManager;
 import android.net.NetworkStats;
 import android.net.UnderlyingNetworkInfo;
-import android.os.RemoteException;
 import android.os.StrictMode;
 import android.os.SystemClock;
 
@@ -70,7 +70,7 @@
 
     private final boolean mUseBpfStats;
 
-    private final INetd mNetd;
+    private final Context mContext;
 
     /**
      * Guards persistent data access in this class
@@ -158,12 +158,12 @@
         NetworkStats.apply464xlatAdjustments(baseTraffic, stackedTraffic, mStackedIfaces);
     }
 
-    public NetworkStatsFactory(@NonNull INetd netd) {
-        this(new File("/proc/"), true, netd);
+    public NetworkStatsFactory(@NonNull Context ctx) {
+        this(ctx, new File("/proc/"), true);
     }
 
     @VisibleForTesting
-    public NetworkStatsFactory(File procRoot, boolean useBpfStats, @NonNull INetd netd) {
+    public NetworkStatsFactory(@NonNull Context ctx, File procRoot, boolean useBpfStats) {
         mStatsXtIfaceAll = new File(procRoot, "net/xt_qtaguid/iface_stat_all");
         mStatsXtIfaceFmt = new File(procRoot, "net/xt_qtaguid/iface_stat_fmt");
         mStatsXtUid = new File(procRoot, "net/xt_qtaguid/stats");
@@ -172,7 +172,7 @@
             mPersistSnapshot = new NetworkStats(SystemClock.elapsedRealtime(), -1);
             mTunAnd464xlatAdjustedStats = new NetworkStats(SystemClock.elapsedRealtime(), -1);
         }
-        mNetd = netd;
+        mContext = ctx;
     }
 
     public NetworkStats readBpfNetworkStatsDev() throws IOException {
@@ -295,11 +295,12 @@
     }
 
     @GuardedBy("mPersistentDataLock")
-    private void requestSwapActiveStatsMapLocked() throws RemoteException {
-        // Ask netd to do a active map stats swap. When the binder call successfully returns,
+    private void requestSwapActiveStatsMapLocked() {
+        // Do an active stats map swap. When the binder call successfully returns,
         // the system server should be able to safely read and clean the inactive map
         // without race problem.
-        mNetd.trafficSwapActiveStatsMap();
+        final ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
+        cm.swapActiveStatsMap();
     }
 
     /**
@@ -327,7 +328,7 @@
                 if (mUseBpfStats) {
                     try {
                         requestSwapActiveStatsMapLocked();
-                    } catch (RemoteException e) {
+                    } catch (RuntimeException e) {
                         throw new IOException(e);
                     }
                     // Stats are always read from the inactive map, so they must be read after the
diff --git a/service-t/src/com/android/server/net/NetworkStatsObservers.java b/service-t/src/com/android/server/net/NetworkStatsObservers.java
index b57a4f9..1953624 100644
--- a/service-t/src/com/android/server/net/NetworkStatsObservers.java
+++ b/service-t/src/com/android/server/net/NetworkStatsObservers.java
@@ -26,13 +26,12 @@
 import android.net.NetworkStatsCollection;
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
-import android.os.Bundle;
+import android.net.netstats.IUsageCallback;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.IBinder;
 import android.os.Looper;
 import android.os.Message;
-import android.os.Messenger;
 import android.os.Process;
 import android.os.RemoteException;
 import android.util.ArrayMap;
@@ -75,10 +74,10 @@
      *
      * @return the normalized request wrapped within {@link RequestInfo}.
      */
-    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
-                IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
-        DataUsageRequest request = buildRequest(inputRequest);
-        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
+    public DataUsageRequest register(DataUsageRequest inputRequest, IUsageCallback callback,
+            int callingUid, @NetworkStatsAccess.Level int accessLevel) {
+        DataUsageRequest request = buildRequest(inputRequest, callingUid);
+        RequestInfo requestInfo = buildRequestInfo(request, callback, callingUid,
                 accessLevel);
 
         if (LOGV) Log.v(TAG, "Registering observer for " + request);
@@ -195,10 +194,12 @@
         }
     }
 
-    private DataUsageRequest buildRequest(DataUsageRequest request) {
-        // Cap the minimum threshold to a safe default to avoid too many callbacks
-        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
-        if (thresholdInBytes < request.thresholdInBytes) {
+    private DataUsageRequest buildRequest(DataUsageRequest request, int callingUid) {
+        // For non-system uid, cap the minimum threshold to a safe default to avoid too
+        // many callbacks.
+        long thresholdInBytes = (callingUid == Process.SYSTEM_UID ? request.thresholdInBytes
+                : Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes));
+        if (thresholdInBytes > request.thresholdInBytes) {
             Log.w(TAG, "Threshold was too low for " + request
                     + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
         }
@@ -206,11 +207,10 @@
                 request.template, thresholdInBytes);
     }
 
-    private RequestInfo buildRequestInfo(DataUsageRequest request,
-                Messenger messenger, IBinder binder, int callingUid,
-                @NetworkStatsAccess.Level int accessLevel) {
+    private RequestInfo buildRequestInfo(DataUsageRequest request, IUsageCallback callback,
+            int callingUid, @NetworkStatsAccess.Level int accessLevel) {
         if (accessLevel <= NetworkStatsAccess.Level.USER) {
-            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
+            return new UserUsageRequestInfo(this, request, callback, callingUid,
                     accessLevel);
         } else {
             // Safety check in case a new access level is added and we forgot to update this
@@ -218,7 +218,7 @@
                 throw new IllegalArgumentException(
                         "accessLevel " + accessLevel + " is less than DEVICESUMMARY.");
             }
-            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
+            return new NetworkUsageRequestInfo(this, request, callback, callingUid,
                     accessLevel);
         }
     }
@@ -230,25 +230,23 @@
     private abstract static class RequestInfo implements IBinder.DeathRecipient {
         private final NetworkStatsObservers mStatsObserver;
         protected final DataUsageRequest mRequest;
-        private final Messenger mMessenger;
-        private final IBinder mBinder;
+        private final IUsageCallback mCallback;
         protected final int mCallingUid;
         protected final @NetworkStatsAccess.Level int mAccessLevel;
         protected NetworkStatsRecorder mRecorder;
         protected NetworkStatsCollection mCollection;
 
         RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
-                    Messenger messenger, IBinder binder, int callingUid,
+                IUsageCallback callback, int callingUid,
                     @NetworkStatsAccess.Level int accessLevel) {
             mStatsObserver = statsObserver;
             mRequest = request;
-            mMessenger = messenger;
-            mBinder = binder;
+            mCallback = callback;
             mCallingUid = callingUid;
             mAccessLevel = accessLevel;
 
             try {
-                mBinder.linkToDeath(this, 0);
+                mCallback.asBinder().linkToDeath(this, 0);
             } catch (RemoteException e) {
                 binderDied();
             }
@@ -257,7 +255,7 @@
         @Override
         public void binderDied() {
             if (LOGV) {
-                Log.v(TAG, "RequestInfo binderDied(" + mRequest + ", " + mBinder + ")");
+                Log.v(TAG, "RequestInfo binderDied(" + mRequest + ", " + mCallback + ")");
             }
             mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
             callCallback(NetworkStatsManager.CALLBACK_RELEASED);
@@ -270,9 +268,7 @@
         }
 
         private void unlinkDeathRecipient() {
-            if (mBinder != null) {
-                mBinder.unlinkToDeath(this, 0);
-            }
+            mCallback.asBinder().unlinkToDeath(this, 0);
         }
 
         /**
@@ -294,17 +290,19 @@
         }
 
         private void callCallback(int callbackType) {
-            Bundle bundle = new Bundle();
-            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
-            Message msg = Message.obtain();
-            msg.what = callbackType;
-            msg.setData(bundle);
             try {
                 if (LOGV) {
                     Log.v(TAG, "sending notification " + callbackTypeToName(callbackType)
                             + " for " + mRequest);
                 }
-                mMessenger.send(msg);
+                switch (callbackType) {
+                    case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
+                        mCallback.onThresholdReached(mRequest);
+                        break;
+                    case NetworkStatsManager.CALLBACK_RELEASED:
+                        mCallback.onCallbackReleased(mRequest);
+                        break;
+                }
             } catch (RemoteException e) {
                 // May occur naturally in the race of binder death.
                 Log.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
@@ -334,9 +332,9 @@
 
     private static class NetworkUsageRequestInfo extends RequestInfo {
         NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
-                    Messenger messenger, IBinder binder, int callingUid,
+                IUsageCallback callback, int callingUid,
                     @NetworkStatsAccess.Level int accessLevel) {
-            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
+            super(statsObserver, request, callback, callingUid, accessLevel);
         }
 
         @Override
@@ -376,9 +374,9 @@
 
     private static class UserUsageRequestInfo extends RequestInfo {
         UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
-                    Messenger messenger, IBinder binder, int callingUid,
+                    IUsageCallback callback, int callingUid,
                     @NetworkStatsAccess.Level int accessLevel) {
-            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
+            super(statsObserver, request, callback, callingUid, accessLevel);
         }
 
         @Override
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 4c78dcb..243d621 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -99,6 +99,7 @@
 import android.net.TrafficStats;
 import android.net.UnderlyingNetworkInfo;
 import android.net.Uri;
+import android.net.netstats.IUsageCallback;
 import android.net.netstats.provider.INetworkStatsProvider;
 import android.net.netstats.provider.INetworkStatsProviderCallback;
 import android.net.netstats.provider.NetworkStatsProvider;
@@ -110,7 +111,6 @@
 import android.os.IBinder;
 import android.os.Looper;
 import android.os.Message;
-import android.os.Messenger;
 import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.ServiceSpecificException;
@@ -422,7 +422,7 @@
         final NetworkStatsService service = new NetworkStatsService(context,
                 INetd.Stub.asInterface((IBinder) context.getSystemService(Context.NETD_SERVICE)),
                 alarmManager, wakeLock, getDefaultClock(),
-                new DefaultNetworkStatsSettings(), new NetworkStatsFactory(netd),
+                new DefaultNetworkStatsSettings(), new NetworkStatsFactory(context),
                 new NetworkStatsObservers(), getDefaultSystemDir(), getDefaultBaseDir(),
                 new Dependencies());
 
@@ -1000,8 +1000,17 @@
         }
 
         // TODO: switch to data layer stats once kernel exports
-        // for now, read network layer stats and flatten across all ifaces
-        final NetworkStats networkLayer = readNetworkStatsUidDetail(uid, INTERFACES_ALL, TAG_ALL);
+        // for now, read network layer stats and flatten across all ifaces.
+        // This function is used to query NetworkStats for the caller's own uid. The only caller,
+        // TrafficStats#getDataLayerSnapshotForUid, already claims that no special permission is
+        // needed to query one's own NetworkStats.
+        final long ident = Binder.clearCallingIdentity();
+        final NetworkStats networkLayer;
+        try {
+            networkLayer = readNetworkStatsUidDetail(uid, INTERFACES_ALL, TAG_ALL);
+        } finally {
+            Binder.restoreCallingIdentity(ident);
+        }
 
         // splice in operation counts
         networkLayer.spliceOperationsFrom(mUidOperations);
@@ -1148,21 +1157,20 @@
     }
 
     @Override
-    public DataUsageRequest registerUsageCallback(String callingPackage,
-                DataUsageRequest request, Messenger messenger, IBinder binder) {
+    public DataUsageRequest registerUsageCallback(@NonNull String callingPackage,
+                @NonNull DataUsageRequest request, @NonNull IUsageCallback callback) {
         Objects.requireNonNull(callingPackage, "calling package is null");
         Objects.requireNonNull(request, "DataUsageRequest is null");
         Objects.requireNonNull(request.template, "NetworkTemplate is null");
-        Objects.requireNonNull(messenger, "messenger is null");
-        Objects.requireNonNull(binder, "binder is null");
+        Objects.requireNonNull(callback, "callback is null");
 
         int callingUid = Binder.getCallingUid();
         @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(callingPackage);
         DataUsageRequest normalizedRequest;
         final long token = Binder.clearCallingIdentity();
         try {
-            normalizedRequest = mStatsObservers.register(request, messenger, binder,
-                    callingUid, accessLevel);
+            normalizedRequest = mStatsObservers.register(
+                    request, callback, callingUid, accessLevel);
         } finally {
             Binder.restoreCallingIdentity(token);
         }