Hack and ship: NetworkStats edition.

Some devices use clatd for catching raw IPv4 traffic when running on
a pure-IPv6 carrier network.  In those situations, the per-UID
stats are accounted against the clat iface, so framework users need
to combine both the "base" and "stacked" iface usage together.

This also means that policy rules (like restricting background data
or battery saver) need to apply to the stacked ifaces.

Finally, we need to massage stats data slightly:

-- Currently xt_qtaguid double-counts the clatd traffic *leaving*
the device; both against the original UID on the clat iface, and
against UID 0 on the final egress interface.

-- All clatd traffic *arriving* at the device is missing the extra
IPv6 packet header overhead when accounted against the final UID.

Bug: 12249687, 15459248, 16296564
Change-Id: I0ee59d96831f52782de7a980e4cce9b061902fff
diff --git a/core/java/com/android/internal/net/NetworkStatsFactory.java b/core/java/com/android/internal/net/NetworkStatsFactory.java
index e2a2b1e..506114b 100644
--- a/core/java/com/android/internal/net/NetworkStatsFactory.java
+++ b/core/java/com/android/internal/net/NetworkStatsFactory.java
@@ -25,17 +25,20 @@
 import android.net.NetworkStats;
 import android.os.StrictMode;
 import android.os.SystemClock;
+import android.util.ArrayMap;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.ArrayUtils;
 import com.android.internal.util.ProcFileReader;
 
+import libcore.io.IoUtils;
+
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.net.ProtocolException;
-
-import libcore.io.IoUtils;
+import java.util.Objects;
 
 /**
  * Creates {@link NetworkStats} instances by parsing various {@code /proc/}
@@ -54,6 +57,19 @@
     /** Path to {@code /proc/net/xt_qtaguid/stats}. */
     private final File mStatsXtUid;
 
+    @GuardedBy("sStackedIfaces")
+    private static final ArrayMap<String, String> sStackedIfaces = new ArrayMap<>();
+
+    public static void noteStackedIface(String stackedIface, String baseIface) {
+        synchronized (sStackedIfaces) {
+            if (baseIface != null) {
+                sStackedIfaces.put(stackedIface, baseIface);
+            } else {
+                sStackedIfaces.remove(stackedIface);
+            }
+        }
+    }
+
     public NetworkStatsFactory() {
         this(new File("/proc/"));
     }
@@ -171,8 +187,54 @@
     }
 
     public NetworkStats readNetworkStatsDetail(int limitUid, String[] limitIfaces, int limitTag,
-            NetworkStats lastStats)
-            throws IOException {
+            NetworkStats lastStats) throws IOException {
+        final NetworkStats stats = readNetworkStatsDetailInternal(limitUid, limitIfaces, limitTag,
+                lastStats);
+
+        synchronized (sStackedIfaces) {
+            // Sigh, xt_qtaguid ends up double-counting tx traffic going through
+            // clatd interfaces, so we need to subtract it here.
+            final int size = sStackedIfaces.size();
+            for (int i = 0; i < size; i++) {
+                final String stackedIface = sStackedIfaces.keyAt(i);
+                final String baseIface = sStackedIfaces.valueAt(i);
+
+                // Count up the tx traffic and subtract from root UID on the
+                // base interface.
+                NetworkStats.Entry adjust = new NetworkStats.Entry(baseIface, 0, 0, 0, 0L, 0L, 0L,
+                        0L, 0L);
+                NetworkStats.Entry entry = null;
+                for (int j = 0; j < stats.size(); j++) {
+                    entry = stats.getValues(j, entry);
+                    if (Objects.equals(entry.iface, stackedIface)) {
+                        adjust.txBytes -= entry.txBytes;
+                        adjust.txPackets -= entry.txPackets;
+                    }
+                }
+                stats.combineValues(adjust);
+            }
+        }
+
+        // Double sigh, all rx traffic on clat needs to be tweaked to
+        // account for the dropped IPv6 header size post-unwrap.
+        NetworkStats.Entry entry = null;
+        for (int i = 0; i < stats.size(); i++) {
+            entry = stats.getValues(i, entry);
+            if (entry.iface != null && entry.iface.startsWith("clat")) {
+                // Delta between IPv4 header (20b) and IPv6 header (40b)
+                entry.rxBytes = entry.rxPackets * 20;
+                entry.rxPackets = 0;
+                entry.txBytes = 0;
+                entry.txPackets = 0;
+                stats.combineValues(entry);
+            }
+        }
+
+        return stats;
+    }
+
+    private NetworkStats readNetworkStatsDetailInternal(int limitUid, String[] limitIfaces,
+            int limitTag, NetworkStats lastStats) throws IOException {
         if (USE_NATIVE_PARSING) {
             final NetworkStats stats;
             if (lastStats != null) {
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index 271e9e9..e35ca46 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -62,8 +62,6 @@
 import static android.text.format.DateUtils.HOUR_IN_MILLIS;
 import static android.text.format.DateUtils.MINUTE_IN_MILLIS;
 import static android.text.format.DateUtils.SECOND_IN_MILLIS;
-import static com.android.internal.util.ArrayUtils.appendElement;
-import static com.android.internal.util.ArrayUtils.contains;
 import static com.android.internal.util.Preconditions.checkNotNull;
 import static com.android.server.NetworkManagementService.LIMIT_GLOBAL_ALERT;
 import static com.android.server.NetworkManagementSocketTagger.resetKernelUidStats;
@@ -107,6 +105,8 @@
 import android.provider.Settings.Global;
 import android.telephony.PhoneStateListener;
 import android.telephony.TelephonyManager;
+import android.util.ArrayMap;
+import android.util.ArraySet;
 import android.util.EventLog;
 import android.util.Log;
 import android.util.MathUtils;
@@ -121,14 +121,12 @@
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.server.EventLogTags;
 import com.android.server.connectivity.Tethering;
-import com.google.android.collect.Maps;
 
 import java.io.File;
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.io.PrintWriter;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 
@@ -215,7 +213,9 @@
     private final Object mStatsLock = new Object();
 
     /** Set of currently active ifaces. */
-    private HashMap<String, NetworkIdentitySet> mActiveIfaces = Maps.newHashMap();
+    private final ArrayMap<String, NetworkIdentitySet> mActiveIfaces = new ArrayMap<>();
+    /** Set of currently active ifaces for UID stats. */
+    private final ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces = new ArrayMap<>();
     /** Current default active iface. */
     private String mActiveIface;
     /** Set of any ifaces associated with mobile networks since boot. */
@@ -883,30 +883,52 @@
 
         mActiveIface = activeLink != null ? activeLink.getInterfaceName() : null;
 
-        // rebuild active interfaces based on connected networks
+        // Rebuild active interfaces based on connected networks
         mActiveIfaces.clear();
+        mActiveUidIfaces.clear();
 
+        final ArraySet<String> mobileIfaces = new ArraySet<>();
         for (NetworkState state : states) {
             if (state.networkInfo.isConnected()) {
-                // collect networks under their parent interfaces
-                final String iface = state.linkProperties.getInterfaceName();
+                final boolean isMobile = isNetworkTypeMobile(state.networkInfo.getType());
+                final NetworkIdentity ident = NetworkIdentity.buildNetworkIdentity(mContext, state);
 
-                NetworkIdentitySet ident = mActiveIfaces.get(iface);
-                if (ident == null) {
-                    ident = new NetworkIdentitySet();
-                    mActiveIfaces.put(iface, ident);
+                // Traffic occurring on the base interface is always counted for
+                // both total usage and UID details.
+                final String baseIface = state.linkProperties.getInterfaceName();
+                findOrCreateNetworkIdentitySet(mActiveIfaces, baseIface).add(ident);
+                findOrCreateNetworkIdentitySet(mActiveUidIfaces, baseIface).add(ident);
+                if (isMobile) {
+                    mobileIfaces.add(baseIface);
                 }
 
-                ident.add(NetworkIdentity.buildNetworkIdentity(mContext, state));
-
-                // remember any ifaces associated with mobile networks
-                if (isNetworkTypeMobile(state.networkInfo.getType()) && iface != null) {
-                    if (!contains(mMobileIfaces, iface)) {
-                        mMobileIfaces = appendElement(String.class, mMobileIfaces, iface);
+                // Traffic occurring on stacked interfaces is usually clatd,
+                // which is already accounted against its final egress interface
+                // by the kernel. Thus, we only need to collect stacked
+                // interface stats at the UID level.
+                final List<LinkProperties> stackedLinks = state.linkProperties.getStackedLinks();
+                for (LinkProperties stackedLink : stackedLinks) {
+                    final String stackedIface = stackedLink.getInterfaceName();
+                    findOrCreateNetworkIdentitySet(mActiveUidIfaces, stackedIface).add(ident);
+                    if (isMobile) {
+                        mobileIfaces.add(stackedIface);
                     }
                 }
             }
         }
+
+        mobileIfaces.remove(null);
+        mMobileIfaces = mobileIfaces.toArray(new String[mobileIfaces.size()]);
+    }
+
+    private static <K> NetworkIdentitySet findOrCreateNetworkIdentitySet(
+            ArrayMap<K, NetworkIdentitySet> map, K key) {
+        NetworkIdentitySet ident = map.get(key);
+        if (ident == null) {
+            ident = new NetworkIdentitySet();
+            map.put(key, ident);
+        }
+        return ident;
     }
 
     /**
@@ -926,8 +948,8 @@
 
             mDevRecorder.recordSnapshotLocked(devSnapshot, mActiveIfaces, currentTime);
             mXtRecorder.recordSnapshotLocked(xtSnapshot, mActiveIfaces, currentTime);
-            mUidRecorder.recordSnapshotLocked(uidSnapshot, mActiveIfaces, currentTime);
-            mUidTagRecorder.recordSnapshotLocked(uidSnapshot, mActiveIfaces, currentTime);
+            mUidRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, currentTime);
+            mUidTagRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, currentTime);
 
         } catch (IllegalStateException e) {
             Slog.w(TAG, "problem reading network stats: " + e);
@@ -980,8 +1002,8 @@
 
             mDevRecorder.recordSnapshotLocked(devSnapshot, mActiveIfaces, currentTime);
             mXtRecorder.recordSnapshotLocked(xtSnapshot, mActiveIfaces, currentTime);
-            mUidRecorder.recordSnapshotLocked(uidSnapshot, mActiveIfaces, currentTime);
-            mUidTagRecorder.recordSnapshotLocked(uidSnapshot, mActiveIfaces, currentTime);
+            mUidRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, currentTime);
+            mUidTagRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, currentTime);
 
         } catch (IllegalStateException e) {
             Log.wtf(TAG, "problem reading network stats", e);
@@ -1136,10 +1158,19 @@
 
             pw.println("Active interfaces:");
             pw.increaseIndent();
-            for (String iface : mActiveIfaces.keySet()) {
-                final NetworkIdentitySet ident = mActiveIfaces.get(iface);
-                pw.print("iface="); pw.print(iface);
-                pw.print(" ident="); pw.println(ident.toString());
+            for (int i = 0; i < mActiveIfaces.size(); i++) {
+                pw.printPair("iface", mActiveIfaces.keyAt(i));
+                pw.printPair("ident", mActiveIfaces.valueAt(i));
+                pw.println();
+            }
+            pw.decreaseIndent();
+
+            pw.println("Active UID interfaces:");
+            pw.increaseIndent();
+            for (int i = 0; i < mActiveUidIfaces.size(); i++) {
+                pw.printPair("iface", mActiveUidIfaces.keyAt(i));
+                pw.printPair("ident", mActiveUidIfaces.valueAt(i));
+                pw.println();
             }
             pw.decreaseIndent();