Improve count of openSession callers in dumpsys

Currently, it's only uid=count in dumpsys, however, the call may also
come from system but there is no such information in dumpsys. If the
call came from system, it's hard to know who is the caller. So, this
patch makes pid, uid and package as a key to record the count of the
caller.

The output will look like,
Top openSession callers:
  {pid=1787,uid=1000,package=android}=26
  {pid=2931,uid=1000,package=com.android.settings}=4

Test: adb shell dumpsys netstats and check the output
Bug: 228081549
Change-Id: Id2dfdd4480aa5b1ece51e46de6efe30a46629811
(cherry picked from commit b87b5e5d0a4f08f58d85ff11d39a64e2c4e4646b)
Merged-In: Id2dfdd4480aa5b1ece51e46de6efe30a46629811
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 82b1fb5..391356c 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -159,8 +159,11 @@
 import java.time.ZoneOffset;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.Executor;
@@ -374,9 +377,9 @@
 
     private long mLastStatsSessionPoll;
 
-    /** Map from UID to number of opened sessions */
-    @GuardedBy("mOpenSessionCallsPerUid")
-    private final SparseIntArray mOpenSessionCallsPerUid = new SparseIntArray();
+    /** Map from key {@code OpenSessionKey} to count of opened sessions */
+    @GuardedBy("mOpenSessionCallsPerCaller")
+    private final HashMap<OpenSessionKey, Integer> mOpenSessionCallsPerCaller = new HashMap<>();
 
     private final static int DUMP_STATS_SESSION_COUNT = 20;
 
@@ -407,6 +410,48 @@
                 Clock.systemUTC());
     }
 
+    /**
+     * This class is a key that used in {@code mOpenSessionCallsPerCaller} to identify the count of
+     * the caller.
+     */
+    private static class OpenSessionKey {
+        public final int mPid;
+        public final int mUid;
+        public final String mPackage;
+
+        OpenSessionKey(int pid, int uid, @NonNull String packageName) {
+            mPid = pid;
+            mUid = uid;
+            mPackage = packageName;
+        }
+
+        @Override
+        public String toString() {
+            final StringBuilder sb = new StringBuilder();
+            sb.append("{");
+            sb.append("pid=").append(mPid).append(",");
+            sb.append("uid=").append(mUid).append(",");
+            sb.append("package=").append(mPackage);
+            sb.append("}");
+            return sb.toString();
+        }
+
+        @Override
+        public boolean equals(@NonNull Object o) {
+            if (this == o) return true;
+            if (o.getClass() != getClass()) return false;
+
+            final OpenSessionKey key = (OpenSessionKey) o;
+            return mPid == key.mPid && mUid == key.mUid
+                    && TextUtils.equals(mPackage, key.mPackage);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(mPid, mUid, mPackage);
+        }
+    }
+
     private final class NetworkStatsHandler extends Handler {
         NetworkStatsHandler(@NonNull Looper looper) {
             super(looper);
@@ -794,24 +839,29 @@
         return openSessionInternal(flags, callingPackage);
     }
 
-    private boolean isRateLimitedForPoll(int callingUid) {
-        if (callingUid == android.os.Process.SYSTEM_UID) {
-            return false;
-        }
-
+    private boolean isRateLimitedForPoll(@NonNull OpenSessionKey key) {
         final long lastCallTime;
         final long now = SystemClock.elapsedRealtime();
-        synchronized (mOpenSessionCallsPerUid) {
-            int calls = mOpenSessionCallsPerUid.get(callingUid, 0);
-            mOpenSessionCallsPerUid.put(callingUid, calls + 1);
+
+        synchronized (mOpenSessionCallsPerCaller) {
+            Integer calls = mOpenSessionCallsPerCaller.get(key);
+            if (calls == null) {
+                mOpenSessionCallsPerCaller.put((key), 1);
+            } else {
+                mOpenSessionCallsPerCaller.put(key, Integer.sum(calls, 1));
+            }
             lastCallTime = mLastStatsSessionPoll;
             mLastStatsSessionPoll = now;
         }
 
+        if (key.mUid == android.os.Process.SYSTEM_UID) {
+            return false;
+        }
+
         return now - lastCallTime < POLL_RATE_LIMIT_MS;
     }
 
-    private int restrictFlagsForCaller(int flags) {
+    private int restrictFlagsForCaller(int flags, @NonNull String callingPackage) {
         // All non-privileged callers are not allowed to turn off POLL_ON_OPEN.
         final boolean isPrivileged = PermissionUtils.checkAnyPermissionOf(mContext,
                 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
@@ -820,15 +870,17 @@
             flags |= NetworkStatsManager.FLAG_POLL_ON_OPEN;
         }
         // Non-system uids are rate limited for POLL_ON_OPEN.
+        final int callingPid = Binder.getCallingPid();
         final int callingUid = Binder.getCallingUid();
-        flags = isRateLimitedForPoll(callingUid)
+        final OpenSessionKey key = new OpenSessionKey(callingPid, callingUid, callingPackage);
+        flags = isRateLimitedForPoll(key)
                 ? flags & (~NetworkStatsManager.FLAG_POLL_ON_OPEN)
                 : flags;
         return flags;
     }
 
     private INetworkStatsSession openSessionInternal(final int flags, final String callingPackage) {
-        final int restrictedFlags = restrictFlagsForCaller(flags);
+        final int restrictedFlags = restrictFlagsForCaller(flags, callingPackage);
         if ((restrictedFlags & (NetworkStatsManager.FLAG_POLL_ON_OPEN
                 | NetworkStatsManager.FLAG_POLL_FORCE)) != 0) {
             final long ident = Binder.clearCallingIdentity();
@@ -2060,25 +2112,21 @@
             pw.decreaseIndent();
 
             // Get the top openSession callers
-            final SparseIntArray calls;
-            synchronized (mOpenSessionCallsPerUid) {
-                calls = mOpenSessionCallsPerUid.clone();
+            final HashMap calls;
+            synchronized (mOpenSessionCallsPerCaller) {
+                calls = new HashMap<>(mOpenSessionCallsPerCaller);
             }
-
-            final int N = calls.size();
-            final long[] values = new long[N];
-            for (int j = 0; j < N; j++) {
-                values[j] = ((long) calls.valueAt(j) << 32) | calls.keyAt(j);
-            }
-            Arrays.sort(values);
-
-            pw.println("Top openSession callers (uid=count):");
+            final List<Map.Entry<OpenSessionKey, Integer>> list = new ArrayList<>(calls.entrySet());
+            Collections.sort(list,
+                    (left, right) -> Integer.compare(left.getValue(), right.getValue()));
+            final int num = list.size();
+            final int end = Math.max(0, num - DUMP_STATS_SESSION_COUNT);
+            pw.println("Top openSession callers:");
             pw.increaseIndent();
-            final int end = Math.max(0, N - DUMP_STATS_SESSION_COUNT);
-            for (int j = N - 1; j >= end; j--) {
-                final int uid = (int) (values[j] & 0xffffffff);
-                final int count = (int) (values[j] >> 32);
-                pw.print(uid); pw.print("="); pw.println(count);
+            for (int j = num - 1; j >= end; j--) {
+                final Map.Entry<OpenSessionKey, Integer> entry = list.get(j);
+                pw.print(entry.getKey()); pw.print("="); pw.println(entry.getValue());
+
             }
             pw.decreaseIndent();
             pw.println();