Merge "DO NOT MERGE: Merge Oreo MR1 into master"
diff --git a/core/java/android/app/usage/NetworkStats.java b/core/java/android/app/usage/NetworkStats.java
index 3670b91..222e9a0 100644
--- a/core/java/android/app/usage/NetworkStats.java
+++ b/core/java/android/app/usage/NetworkStats.java
@@ -97,12 +97,12 @@
     private NetworkStatsHistory.Entry mRecycledHistoryEntry = null;
 
     /** @hide */
-    NetworkStats(Context context, NetworkTemplate template, long startTimestamp,
+    NetworkStats(Context context, NetworkTemplate template, int flags, long startTimestamp,
             long endTimestamp) throws RemoteException, SecurityException {
         final INetworkStatsService statsService = INetworkStatsService.Stub.asInterface(
                 ServiceManager.getService(Context.NETWORK_STATS_SERVICE));
         // Open network stats session
-        mSession = statsService.openSessionForUsageStats(context.getOpPackageName());
+        mSession = statsService.openSessionForUsageStats(flags, context.getOpPackageName());
         mCloseGuard.open("close");
         mTemplate = template;
         mStartTimeStamp = startTimestamp;
diff --git a/core/java/android/app/usage/NetworkStatsManager.java b/core/java/android/app/usage/NetworkStatsManager.java
index ef262e0..853b003 100644
--- a/core/java/android/app/usage/NetworkStatsManager.java
+++ b/core/java/android/app/usage/NetworkStatsManager.java
@@ -24,15 +24,14 @@
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.DataUsageRequest;
+import android.net.INetworkStatsService;
 import android.net.NetworkIdentity;
 import android.net.NetworkTemplate;
-import android.net.INetworkStatsService;
 import android.os.Binder;
-import android.os.Build;
-import android.os.Message;
-import android.os.Messenger;
 import android.os.Handler;
 import android.os.Looper;
+import android.os.Message;
+import android.os.Messenger;
 import android.os.RemoteException;
 import android.os.ServiceManager;
 import android.os.ServiceManager.ServiceNotFoundException;
@@ -79,7 +78,7 @@
  * In addition to tethering usage, usage by removed users and apps, and usage by the system
  * is also included in the results for callers with one of these higher levels of access.
  * <p />
- * <b>NOTE:</b> Prior to API level {@value Build.VERSION_CODES#N}, all calls to these APIs required
+ * <b>NOTE:</b> Prior to API level {@value android.os.Build.VERSION_CODES#N}, all calls to these APIs required
  * the above permission, even to access an app's own data usage, and carrier-privileged apps were
  * not included.
  */
@@ -96,6 +95,13 @@
     private final Context mContext;
     private final INetworkStatsService mService;
 
+    /** @hide */
+    public static final int FLAG_POLL_ON_OPEN = 1 << 0;
+    /** @hide */
+    public static final int FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN = 1 << 1;
+
+    private int mFlags;
+
     /**
      * {@hide}
      */
@@ -103,6 +109,25 @@
         mContext = context;
         mService = INetworkStatsService.Stub.asInterface(
                 ServiceManager.getServiceOrThrow(Context.NETWORK_STATS_SERVICE));
+        setPollOnOpen(true);
+    }
+
+    /**
+     * Toggles {@link #FLAG_POLL_ON_OPEN}, asking the service to poll for fresh
+     * stats whenever a query from this manager opens a session.
+     * @hide
+     */
+    public void setPollOnOpen(boolean pollOnOpen) {
+        mFlags = pollOnOpen ? (mFlags | FLAG_POLL_ON_OPEN) : (mFlags & ~FLAG_POLL_ON_OPEN);
+    }
+
+    /**
+     * Toggles {@link #FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN} for subsequent queries.
+     * @hide
+     */
+    public void setAugmentWithSubscriptionPlan(boolean augmentWithSubscriptionPlan) {
+        mFlags = augmentWithSubscriptionPlan ? (mFlags | FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN)
+                : (mFlags & ~FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN);
     }
 
     /**
@@ -136,7 +161,7 @@
         }
 
         Bucket bucket = null;
-        NetworkStats stats = new NetworkStats(mContext, template, startTime, endTime);
+        NetworkStats stats = new NetworkStats(mContext, template, mFlags, startTime, endTime);
         bucket = stats.getDeviceSummaryForNetwork();
 
         stats.close();
@@ -174,7 +199,7 @@
         }
 
         NetworkStats stats;
-        stats = new NetworkStats(mContext, template, startTime, endTime);
+        stats = new NetworkStats(mContext, template, mFlags, startTime, endTime);
         stats.startSummaryEnumeration();
 
         stats.close();
@@ -211,7 +236,7 @@
         }
 
         NetworkStats result;
-        result = new NetworkStats(mContext, template, startTime, endTime);
+        result = new NetworkStats(mContext, template, mFlags, startTime, endTime);
         result.startSummaryEnumeration();
 
         return result;
@@ -260,7 +285,7 @@
 
         NetworkStats result;
         try {
-            result = new NetworkStats(mContext, template, startTime, endTime);
+            result = new NetworkStats(mContext, template, mFlags, startTime, endTime);
             result.startHistoryEnumeration(uid, tag);
         } catch (RemoteException e) {
             Log.e(TAG, "Error while querying stats for uid=" + uid + " tag=" + tag, e);
@@ -305,7 +330,7 @@
         }
 
         NetworkStats result;
-        result = new NetworkStats(mContext, template, startTime, endTime);
+        result = new NetworkStats(mContext, template, mFlags, startTime, endTime);
         result.startUserUidEnumeration();
         return result;
     }
diff --git a/core/java/android/net/INetworkStatsService.aidl b/core/java/android/net/INetworkStatsService.aidl
index e693009..9180112 100644
--- a/core/java/android/net/INetworkStatsService.aidl
+++ b/core/java/android/net/INetworkStatsService.aidl
@@ -36,7 +36,7 @@
      *  PACKAGE_USAGE_STATS permission is always checked. If PACKAGE_USAGE_STATS is not granted
      *  READ_NETWORK_USAGE_STATS is checked for.
      */
-    INetworkStatsSession openSessionForUsageStats(String callingPackage);
+    INetworkStatsSession openSessionForUsageStats(int flags, String callingPackage);
 
     /** Return network layer usage total for traffic that matches template. */
     long getNetworkTotalBytes(in NetworkTemplate template, long start, long end);
diff --git a/core/java/android/net/NetworkIdentity.java b/core/java/android/net/NetworkIdentity.java
index df404b7..d3b3599 100644
--- a/core/java/android/net/NetworkIdentity.java
+++ b/core/java/android/net/NetworkIdentity.java
@@ -157,7 +157,7 @@
      * Scrub given IMSI on production builds.
      */
     public static String scrubSubscriberId(String subscriberId) {
-        if ("eng".equals(Build.TYPE)) {
+        if (Build.IS_ENG) {
             return subscriberId;
         } else if (subscriberId != null) {
             // TODO: parse this as MCC+MNC instead of hard-coding
diff --git a/core/java/android/net/NetworkStatsHistory.java b/core/java/android/net/NetworkStatsHistory.java
index 5f521de..433f941 100644
--- a/core/java/android/net/NetworkStatsHistory.java
+++ b/core/java/android/net/NetworkStatsHistory.java
@@ -27,6 +27,7 @@
 import static android.net.NetworkStatsHistory.ParcelUtils.readLongArray;
 import static android.net.NetworkStatsHistory.ParcelUtils.writeLongArray;
 import static android.text.format.DateUtils.SECOND_IN_MILLIS;
+
 import static com.android.internal.util.ArrayUtils.total;
 
 import android.os.Parcel;
@@ -282,6 +283,24 @@
         return entry;
     }
 
+    /**
+     * Overwrite bucket {@code i} with the values carried by {@code entry},
+     * adjusting {@code totalBytes} by the change in that bucket's rx/tx bytes.
+     */
+    public void setValues(int i, Entry entry) {
+        final boolean hasRx = rxBytes != null;
+        final boolean hasTx = txBytes != null;
+        final long before = (hasRx ? rxBytes[i] : 0) + (hasTx ? txBytes[i] : 0);
+        bucketStart[i] = entry.bucketStart;
+        setLong(activeTime, i, entry.activeTime);
+        setLong(rxBytes, i, entry.rxBytes);
+        setLong(rxPackets, i, entry.rxPackets);
+        setLong(txBytes, i, entry.txBytes);
+        setLong(txPackets, i, entry.txPackets);
+        setLong(operations, i, entry.operations);
+        totalBytes += (hasRx ? rxBytes[i] : 0) + (hasTx ? txBytes[i] : 0) - before;
+    }
+
     /**
      * Record that data traffic occurred in the given time range. Will
      * distribute across internal buckets, creating new buckets as needed.
diff --git a/core/java/android/net/NetworkTemplate.java b/core/java/android/net/NetworkTemplate.java
index 0d2fcd0..b307c5d 100644
--- a/core/java/android/net/NetworkTemplate.java
+++ b/core/java/android/net/NetworkTemplate.java
@@ -326,6 +326,11 @@
         }
     }
 
+    /** Return whether {@code subscriberId} is one of this template's subscriber IDs. */
+    public boolean matchesSubscriberId(String subscriberId) {
+        return ArrayUtils.contains(mMatchSubscriberIds, subscriberId);
+    }
+
     /**
      * Check if mobile network with matching IMSI.
      */
diff --git a/core/java/android/net/TrafficStats.java b/core/java/android/net/TrafficStats.java
index f934616..c339856 100644
--- a/core/java/android/net/TrafficStats.java
+++ b/core/java/android/net/TrafficStats.java
@@ -109,25 +109,26 @@
      */
     public static final int TAG_SYSTEM_RESTORE = 0xFFFFFF04;
 
-    /** @hide */
-    public static final int TAG_SYSTEM_DHCP = 0xFFFFFF05;
-    /** @hide */
-    public static final int TAG_SYSTEM_NTP = 0xFFFFFF06;
-    /** @hide */
-    public static final int TAG_SYSTEM_PROBE = 0xFFFFFF07;
-    /** @hide */
-    public static final int TAG_SYSTEM_NEIGHBOR = 0xFFFFFF08;
-    /** @hide */
-    public static final int TAG_SYSTEM_GPS = 0xFFFFFF09;
-    /** @hide */
-    public static final int TAG_SYSTEM_PAC = 0xFFFFFF0A;
-
     /**
-     * Sockets that are strictly local on device; never hits network.
+     * Default tag value for code (typically APKs) downloaded by an app store on
+     * behalf of the app, such as updates.
      *
      * @hide
      */
-    public static final int TAG_SYSTEM_LOCAL = 0xFFFFFFAA;
+    public static final int TAG_SYSTEM_APP = 0xFFFFFF05;
+
+    /** @hide */
+    public static final int TAG_SYSTEM_DHCP = 0xFFFFFF40;
+    /** @hide */
+    public static final int TAG_SYSTEM_NTP = 0xFFFFFF41;
+    /** @hide */
+    public static final int TAG_SYSTEM_PROBE = 0xFFFFFF42;
+    /** @hide */
+    public static final int TAG_SYSTEM_NEIGHBOR = 0xFFFFFF43;
+    /** @hide */
+    public static final int TAG_SYSTEM_GPS = 0xFFFFFF44;
+    /** @hide */
+    public static final int TAG_SYSTEM_PAC = 0xFFFFFF45;
 
     private static INetworkStatsService sStatsService;
 
@@ -210,6 +211,19 @@
     }
 
     /**
+     * Set active tag to use when accounting {@link Socket} traffic originating
+     * from the current thread. The tag is a fixed, well-known value reserved
+     * for code (typically APKs) downloaded by an app store on behalf of an
+     * app, such as updates, so that this traffic can be told apart.
+     *
+     * @hide
+     */
+    @SystemApi
+    public static void setThreadStatsTagApp() {
+        setThreadStatsTag(TAG_SYSTEM_APP);
+    }
+
+    /**
      * Get the active tag used when accounting {@link Socket} traffic originating
      * from the current thread. Only one active tag per thread is supported.
      * {@link #tagSocket(Socket)}.
diff --git a/services/core/java/com/android/server/net/NetworkStatsCollection.java b/services/core/java/com/android/server/net/NetworkStatsCollection.java
index 0354300..4ceb592 100644
--- a/services/core/java/com/android/server/net/NetworkStatsCollection.java
+++ b/services/core/java/com/android/server/net/NetworkStatsCollection.java
@@ -28,6 +28,8 @@
 import static android.net.TrafficStats.UID_REMOVED;
 import static android.text.format.DateUtils.WEEK_IN_MILLIS;
 
+import static com.android.server.net.NetworkStatsService.TAG;
+
 import android.net.NetworkIdentity;
 import android.net.NetworkStats;
 import android.net.NetworkStatsHistory;
@@ -37,20 +39,24 @@
 import android.service.NetworkStatsCollectionKeyProto;
 import android.service.NetworkStatsCollectionProto;
 import android.service.NetworkStatsCollectionStatsProto;
+import android.telephony.SubscriptionPlan;
 import android.util.ArrayMap;
 import android.util.AtomicFile;
 import android.util.IntArray;
+import android.util.Pair;
+import android.util.Slog;
 import android.util.proto.ProtoOutputStream;
 
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.ArrayUtils;
 import com.android.internal.util.FileRotator;
 import com.android.internal.util.IndentingPrintWriter;
 
+import libcore.io.IoUtils;
+
 import com.google.android.collect.Lists;
 import com.google.android.collect.Maps;
 
-import libcore.io.IoUtils;
-
 import java.io.BufferedInputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
@@ -60,9 +66,11 @@
 import java.io.InputStream;
 import java.io.PrintWriter;
 import java.net.ProtocolException;
+import java.time.ZonedDateTime;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.Objects;
 
 /**
@@ -140,6 +148,63 @@
         return mStartMillis == Long.MAX_VALUE && mEndMillis == Long.MIN_VALUE;
     }
 
+    /**
+     * Round {@code time} up to a bucket boundary (only positive remainders are
+     * adjusted); Long.MIN_VALUE, Long.MAX_VALUE and TIME_UNKNOWN pass through.
+     */
+    @VisibleForTesting
+    public long roundUp(long time) {
+        if (isUnboundedOrUnknown(time)) return time;
+        final long remainder = time % mBucketDuration;
+        return (remainder > 0) ? (time - remainder) + mBucketDuration : time;
+    }
+
+    /**
+     * Round {@code time} down to a bucket boundary (only positive remainders are
+     * adjusted); Long.MIN_VALUE, Long.MAX_VALUE and TIME_UNKNOWN pass through.
+     */
+    @VisibleForTesting
+    public long roundDown(long time) {
+        if (isUnboundedOrUnknown(time)) return time;
+        final long remainder = time % mBucketDuration;
+        return (remainder > 0) ? time - remainder : time;
+    }
+
+    /** Whether {@code time} is a sentinel that the rounding helpers leave alone. */
+    private static boolean isUnboundedOrUnknown(long time) {
+        return time == Long.MIN_VALUE
+                || time == Long.MAX_VALUE
+                || time == SubscriptionPlan.TIME_UNKNOWN;
+    }
+
+    /**
+     * Safely multiply a value by a rational, computing {@code value * num / den}.
+     * <p>
+     * Integer math is used whenever possible; if {@code value * num} would
+     * overflow it switches to double math instead. {@code den} must be non-zero.
+     */
+    @VisibleForTesting
+    public static long multiplySafe(long value, long num, long den) {
+        long x = value;
+        long y = num;
+
+        // Overflow detection shamelessly borrowed from Math.multiplyExact()
+        long r = x * y;
+        long ax = Math.abs(x);
+        long ay = Math.abs(y);
+        if (((ax | ay) >>> 31 != 0)) {
+            // Some bits greater than 2^31 that might cause overflow
+            // Check the result using the divide operator
+            // and check for the special case of Long.MIN_VALUE * -1
+            if (((y != 0) && (r / y != x)) ||
+                    (x == Long.MIN_VALUE && y == -1)) {
+                // Use double math to avoid overflowing
+                return (long) (((double) num / den) * value);
+            }
+        }
+        return r / den;
+    }
+
     public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel) {
         return getRelevantUids(accessLevel, Binder.getCallingUid());
     }
@@ -165,60 +230,124 @@
      * Combine all {@link NetworkStatsHistory} in this collection which match
      * the requested parameters.
-     */
-    public NetworkStatsHistory getHistory(
-            NetworkTemplate template, int uid, int set, int tag, int fields,
-            @NetworkStatsAccess.Level int accessLevel) {
-        return getHistory(template, uid, set, tag, fields, Long.MIN_VALUE, Long.MAX_VALUE,
-                accessLevel);
-    }
-
-    /**
-     * Combine all {@link NetworkStatsHistory} in this collection which match
-     * the requested parameters.
-     */
-    public NetworkStatsHistory getHistory(
-            NetworkTemplate template, int uid, int set, int tag, int fields, long start, long end,
-            @NetworkStatsAccess.Level int accessLevel) {
-        return getHistory(template, uid, set, tag, fields, start, end, accessLevel,
-                Binder.getCallingUid());
-    }
-
-    /**
-     * Combine all {@link NetworkStatsHistory} in this collection which match
-     * the requested parameters.
-     */
-    public NetworkStatsHistory getHistory(
-            NetworkTemplate template, int uid, int set, int tag, int fields, long start, long end,
+     * <p>
+     * If {@code augmentPlan} is non-null and has a known data usage time, the
+     * whole buckets from the start of the plan cycle containing that time up
+     * to the time itself are rescaled so their total approximately matches the
+     * plan's data usage bytes; their packet counts are cleared to mark them as
+     * augmented. The returned history only covers {@code start..end}.
+     */
+    public NetworkStatsHistory getHistory(NetworkTemplate template, SubscriptionPlan augmentPlan,
+            int uid, int set, int tag, int fields, long start, long end,
             @NetworkStatsAccess.Level int accessLevel, int callerUid) {
         if (!NetworkStatsAccess.isAccessibleToUser(uid, callerUid, accessLevel)) {
             throw new SecurityException("Network stats history of uid " + uid
                     + " is forbidden for caller " + callerUid);
         }
 
+        // bucketEstimate is only a sizing hint, so bound it to [0, ~2 years of
+        // buckets]: huge or overflowing ranges (e.g. end == Long.MAX_VALUE) must
+        // not turn into a negative or enormous size.
+        final long maxBucketEstimate = (104 * WEEK_IN_MILLIS) / mBucketDuration;
+        final int bucketEstimate = (int) Math.max(0,
+                Math.min((end - start) / mBucketDuration, maxBucketEstimate));
         final NetworkStatsHistory combined = new NetworkStatsHistory(
-                mBucketDuration, start == end ? 1 : estimateBuckets(), fields);
+                mBucketDuration, bucketEstimate, fields);
 
         // shortcut when we know stats will be empty
         if (start == end) return combined;
 
+        // Figure out the window of time that we should be augmenting (if any)
+        long augmentStart = SubscriptionPlan.TIME_UNKNOWN;
+        long augmentEnd = (augmentPlan != null) ? augmentPlan.getDataUsageTime()
+                : SubscriptionPlan.TIME_UNKNOWN;
+        // And if augmenting, we might need to collect more data to adjust with
+        long collectStart = start;
+        long collectEnd = end;
+
+        if (augmentEnd != SubscriptionPlan.TIME_UNKNOWN) {
+            final Iterator<Pair<ZonedDateTime, ZonedDateTime>> it = augmentPlan.cycleIterator();
+            while (it.hasNext()) {
+                final Pair<ZonedDateTime, ZonedDateTime> cycle = it.next();
+                final long cycleStart = cycle.first.toInstant().toEpochMilli();
+                final long cycleEnd = cycle.second.toInstant().toEpochMilli();
+                if (cycleStart <= augmentEnd && augmentEnd < cycleEnd) {
+                    augmentStart = cycleStart;
+                    collectStart = Long.min(collectStart, augmentStart);
+                    collectEnd = Long.max(collectEnd, augmentEnd);
+                    break;
+                }
+            }
+        }
+
+        if (augmentStart != SubscriptionPlan.TIME_UNKNOWN) {
+            // Shrink augmentation window so we don't risk undercounting.
+            augmentStart = roundUp(augmentStart);
+            augmentEnd = roundDown(augmentEnd);
+            // Grow collection window so we get all the stats needed.
+            collectStart = roundDown(collectStart);
+            collectEnd = roundUp(collectEnd);
+        }
+
         for (int i = 0; i < mStats.size(); i++) {
             final Key key = mStats.keyAt(i);
             if (key.uid == uid && NetworkStats.setMatches(set, key.set) && key.tag == tag
                     && templateMatches(template, key.ident)) {
                 final NetworkStatsHistory value = mStats.valueAt(i);
-                combined.recordHistory(value, start, end);
+                combined.recordHistory(value, collectStart, collectEnd);
             }
         }
-        return combined;
-    }
 
-    /**
-     * Summarize all {@link NetworkStatsHistory} in this collection which match
-     * the requested parameters.
-     */
-    public NetworkStats getSummary(NetworkTemplate template, long start, long end,
-            @NetworkStatsAccess.Level int accessLevel) {
-        return getSummary(template, start, end, accessLevel, Binder.getCallingUid());
+        if (augmentStart != SubscriptionPlan.TIME_UNKNOWN) {
+            final NetworkStatsHistory.Entry entry = combined.getValues(
+                    augmentStart, augmentEnd, null);
+
+            // If we don't have any recorded data for this time period, give
+            // ourselves something to scale with.
+            if (entry.rxBytes == 0 || entry.txBytes == 0) {
+                combined.recordData(augmentStart, augmentEnd,
+                        new NetworkStats.Entry(1, 0, 1, 0, 0));
+                combined.getValues(augmentStart, augmentEnd, entry);
+            }
+
+            // Clamp the divisors to at least 1: if rounding left an empty window
+            // (plan cycle start and usage time within one bucket), the seeding
+            // above may record nothing and multiplySafe() would divide by zero.
+            final long rawBytes = Math.max(1, entry.rxBytes + entry.txBytes);
+            final long rawRxBytes = Math.max(1, entry.rxBytes);
+            final long rawTxBytes = Math.max(1, entry.txBytes);
+            final long targetBytes = augmentPlan.getDataUsageBytes();
+            final long targetRxBytes = multiplySafe(targetBytes, rawRxBytes, rawBytes);
+            final long targetTxBytes = multiplySafe(targetBytes, rawTxBytes, rawBytes);
+
+            // Scale all matching buckets to reach anchor target
+            final long beforeTotal = combined.getTotalBytes();
+            for (int i = 0; i < combined.size(); i++) {
+                combined.getValues(i, entry);
+                if (entry.bucketStart >= augmentStart
+                        && entry.bucketStart + entry.bucketDuration <= augmentEnd) {
+                    entry.rxBytes = multiplySafe(targetRxBytes, entry.rxBytes, rawRxBytes);
+                    entry.txBytes = multiplySafe(targetTxBytes, entry.txBytes, rawTxBytes);
+                    // We purposefully clear out packet counters to indicate
+                    // that this data has been augmented.
+                    entry.rxPackets = 0;
+                    entry.txPackets = 0;
+                    combined.setValues(i, entry);
+                }
+            }
+
+            final long deltaTotal = combined.getTotalBytes() - beforeTotal;
+            if (deltaTotal != 0) {
+                Slog.d(TAG, "Augmented network usage by " + deltaTotal + " bytes");
+            }
+
+            // Finally we can slice data as originally requested
+            final NetworkStatsHistory sliced = new NetworkStatsHistory(
+                    mBucketDuration, bucketEstimate, fields);
+            sliced.recordHistory(combined, start, end);
+            return sliced;
+        } else {
+            return combined;
+        }
     }
 
     /**
@@ -230,6 +359,7 @@
         final long now = System.currentTimeMillis();
 
         final NetworkStats stats = new NetworkStats(end - start, 24);
+
         // shortcut when we know stats will be empty
         if (start == end) return stats;
 
diff --git a/services/core/java/com/android/server/net/NetworkStatsObservers.java b/services/core/java/com/android/server/net/NetworkStatsObservers.java
index a256cbc..741c206 100644
--- a/services/core/java/com/android/server/net/NetworkStatsObservers.java
+++ b/services/core/java/com/android/server/net/NetworkStatsObservers.java
@@ -17,28 +17,26 @@
 package com.android.server.net;
 
 import static android.net.TrafficStats.MB_IN_BYTES;
+
 import static com.android.internal.util.Preconditions.checkArgument;
 
 import android.app.usage.NetworkStatsManager;
 import android.net.DataUsageRequest;
 import android.net.NetworkStats;
-import android.net.NetworkStats.NonMonotonicObserver;
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
-import android.os.Binder;
 import android.os.Bundle;
-import android.os.Looper;
-import android.os.Message;
-import android.os.Messenger;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.IBinder;
+import android.os.Looper;
+import android.os.Message;
+import android.os.Messenger;
 import android.os.Process;
 import android.os.RemoteException;
 import android.util.ArrayMap;
-import android.util.IntArray;
-import android.util.SparseArray;
 import android.util.Slog;
+import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.net.VpnInfo;
@@ -410,7 +408,7 @@
          */
         private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
             try {
-                NetworkStatsHistory history = mCollection.getHistory(template, uid,
+                NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
                         NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
                         NetworkStatsHistory.FIELD_ALL,
                         Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
diff --git a/services/core/java/com/android/server/net/NetworkStatsRecorder.java b/services/core/java/com/android/server/net/NetworkStatsRecorder.java
index 80309e1..4bee55e 100644
--- a/services/core/java/com/android/server/net/NetworkStatsRecorder.java
+++ b/services/core/java/com/android/server/net/NetworkStatsRecorder.java
@@ -20,6 +20,7 @@
 import static android.net.TrafficStats.KB_IN_BYTES;
 import static android.net.TrafficStats.MB_IN_BYTES;
 import static android.text.format.DateUtils.YEAR_IN_MILLIS;
+
 import static com.android.internal.util.Preconditions.checkNotNull;
 
 import android.annotation.Nullable;
@@ -28,6 +29,7 @@
 import android.net.NetworkStatsHistory;
 import android.net.NetworkTemplate;
 import android.net.TrafficStats;
+import android.os.Binder;
 import android.os.DropBoxManager;
 import android.service.NetworkStatsRecorderProto;
 import android.util.Log;
@@ -38,6 +40,9 @@
 import com.android.internal.net.VpnInfo;
 import com.android.internal.util.FileRotator;
 import com.android.internal.util.IndentingPrintWriter;
+
+import libcore.io.IoUtils;
+
 import com.google.android.collect.Sets;
 
 import java.io.ByteArrayOutputStream;
@@ -52,8 +57,6 @@
 import java.util.HashSet;
 import java.util.Map;
 
-import libcore.io.IoUtils;
-
 /**
  * Logic to record deltas between periodic {@link NetworkStats} snapshots into
  * {@link NetworkStatsHistory} that belong to {@link NetworkStatsCollection}.
@@ -150,7 +153,7 @@
 
     public NetworkStats.Entry getTotalSinceBootLocked(NetworkTemplate template) {
         return mSinceBoot.getSummary(template, Long.MIN_VALUE, Long.MAX_VALUE,
-                NetworkStatsAccess.Level.DEVICE).getTotal(null);
+                NetworkStatsAccess.Level.DEVICE, Binder.getCallingUid()).getTotal(null);
     }
 
     public NetworkStatsCollection getSinceBoot() {
diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java
index 421db40..3af5265 100644
--- a/services/core/java/com/android/server/net/NetworkStatsService.java
+++ b/services/core/java/com/android/server/net/NetworkStatsService.java
@@ -18,7 +18,6 @@
 
 import static android.Manifest.permission.ACCESS_NETWORK_STATE;
 import static android.Manifest.permission.CONNECTIVITY_INTERNAL;
-import static android.Manifest.permission.MODIFY_NETWORK_ACCOUNTING;
 import static android.Manifest.permission.READ_NETWORK_USAGE_HISTORY;
 import static android.content.Intent.ACTION_SHUTDOWN;
 import static android.content.Intent.ACTION_UID_REMOVED;
@@ -27,6 +26,8 @@
 import static android.net.ConnectivityManager.ACTION_TETHER_STATE_CHANGED;
 import static android.net.ConnectivityManager.isNetworkTypeMobile;
 import static android.net.NetworkStats.IFACE_ALL;
+import static android.net.NetworkStats.METERED_ALL;
+import static android.net.NetworkStats.ROAMING_ALL;
 import static android.net.NetworkStats.SET_ALL;
 import static android.net.NetworkStats.SET_DEFAULT;
 import static android.net.NetworkStats.SET_FOREGROUND;
@@ -34,10 +35,12 @@
 import static android.net.NetworkStats.STATS_PER_UID;
 import static android.net.NetworkStats.TAG_NONE;
 import static android.net.NetworkStats.UID_ALL;
+import static android.net.NetworkStatsHistory.FIELD_ALL;
 import static android.net.NetworkTemplate.buildTemplateMobileWildcard;
 import static android.net.NetworkTemplate.buildTemplateWifiWildcard;
 import static android.net.TrafficStats.KB_IN_BYTES;
 import static android.net.TrafficStats.MB_IN_BYTES;
+import static android.provider.Settings.Global.NETSTATS_AUGMENT_ENABLED;
 import static android.provider.Settings.Global.NETSTATS_DEV_BUCKET_DURATION;
 import static android.provider.Settings.Global.NETSTATS_DEV_DELETE_AGE;
 import static android.provider.Settings.Global.NETSTATS_DEV_PERSIST_BYTES;
@@ -66,6 +69,7 @@
 
 import android.app.AlarmManager;
 import android.app.PendingIntent;
+import android.app.usage.NetworkStatsManager;
 import android.content.BroadcastReceiver;
 import android.content.ContentResolver;
 import android.content.Context;
@@ -105,6 +109,8 @@
 import android.provider.Settings.Global;
 import android.service.NetworkInterfaceProto;
 import android.service.NetworkStatsServiceDumpProto;
+import android.telephony.SubscriptionManager;
+import android.telephony.SubscriptionPlan;
 import android.telephony.TelephonyManager;
 import android.text.format.DateUtils;
 import android.util.ArrayMap;
@@ -118,6 +124,7 @@
 import android.util.TrustedTime;
 import android.util.proto.ProtoOutputStream;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.net.VpnInfo;
 import com.android.internal.util.ArrayUtils;
@@ -140,8 +147,8 @@
  * other system services.
  */
 public class NetworkStatsService extends INetworkStatsService.Stub {
-    private static final String TAG = "NetworkStats";
-    private static final boolean LOGV = false;
+    static final String TAG = "NetworkStats";
+    static final boolean LOGV = false;
 
     private static final int MSG_PERFORM_POLL = 1;
     private static final int MSG_UPDATE_IFACES = 2;
@@ -195,6 +202,7 @@
         public long getPollInterval();
         public long getTimeCacheMaxAge();
         public boolean getSampleEnabled();
+        public boolean getAugmentEnabled();
 
         public static class Config {
             public final long bucketDuration;
@@ -234,12 +242,17 @@
     private final DropBoxNonMonotonicObserver mNonMonotonicObserver =
             new DropBoxNonMonotonicObserver();
 
+    @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mDevRecorder;
+    @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mXtRecorder;
+    @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mUidRecorder;
+    @GuardedBy("mStatsLock")
     private NetworkStatsRecorder mUidTagRecorder;
 
     /** Cached {@link #mXtRecorder} stats. */
+    @GuardedBy("mStatsLock")
     private NetworkStatsCollection mXtStatsCached;
 
     /** Current counter sets for each UID. */
@@ -321,15 +334,15 @@
             return;
         }
 
-        // create data recorders along with historical rotators
-        mDevRecorder = buildRecorder(PREFIX_DEV, mSettings.getDevConfig(), false);
-        mXtRecorder = buildRecorder(PREFIX_XT, mSettings.getXtConfig(), false);
-        mUidRecorder = buildRecorder(PREFIX_UID, mSettings.getUidConfig(), false);
-        mUidTagRecorder = buildRecorder(PREFIX_UID_TAG, mSettings.getUidTagConfig(), true);
-
-        updatePersistThresholds();
-
         synchronized (mStatsLock) {
+            // create data recorders along with historical rotators
+            mDevRecorder = buildRecorder(PREFIX_DEV, mSettings.getDevConfig(), false);
+            mXtRecorder = buildRecorder(PREFIX_XT, mSettings.getXtConfig(), false);
+            mUidRecorder = buildRecorder(PREFIX_UID, mSettings.getUidConfig(), false);
+            mUidTagRecorder = buildRecorder(PREFIX_UID_TAG, mSettings.getUidTagConfig(), true);
+
+            updatePersistThresholdsLocked();
+
             // upgrade any legacy stats, migrating them to rotated files
             maybeUpgradeLegacyStatsLocked();
 
@@ -467,18 +480,20 @@
 
     @Override
     public INetworkStatsSession openSession() {
-        return createSession(null, /* poll on create */ false);
+        // NOTE: if callers want to get non-augmented data, they should go
+        // through the public API
+        return openSessionInternal(NetworkStatsManager.FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN, null);
     }
 
     @Override
-    public INetworkStatsSession openSessionForUsageStats(final String callingPackage) {
-        return createSession(callingPackage, /* poll on create */ true);
+    public INetworkStatsSession openSessionForUsageStats(int flags, String callingPackage) {
+        return openSessionInternal(flags, callingPackage);
     }
 
-    private INetworkStatsSession createSession(final String callingPackage, boolean pollOnCreate) {
+    private INetworkStatsSession openSessionInternal(final int flags, final String callingPackage) {
         assertBandwidthControlEnabled();
 
-        if (pollOnCreate) {
+        if ((flags & NetworkStatsManager.FLAG_POLL_ON_OPEN) != 0) {
             final long ident = Binder.clearCallingIdentity();
             try {
                 performPoll(FLAG_PERSIST_ALL);
@@ -491,9 +506,13 @@
         // for its lifetime; when caller closes only weak references remain.
 
         return new INetworkStatsSession.Stub() {
+            private final int mCallingUid = Binder.getCallingUid();
+            private final String mCallingPackage = callingPackage;
+            private final @NetworkStatsAccess.Level int mAccessLevel = checkAccessLevel(
+                    callingPackage);
+
             private NetworkStatsCollection mUidComplete;
             private NetworkStatsCollection mUidTagComplete;
-            private String mCallingPackage = callingPackage;
 
             private NetworkStatsCollection getUidComplete() {
                 synchronized (mStatsLock) {
@@ -515,55 +534,38 @@
 
             @Override
             public int[] getRelevantUids() {
-                return getUidComplete().getRelevantUids(checkAccessLevel(mCallingPackage));
+                return getUidComplete().getRelevantUids(mAccessLevel); // level fixed at open
             }
 
             @Override
-            public NetworkStats getDeviceSummaryForNetwork(NetworkTemplate template, long start,
-                    long end) {
-                @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(mCallingPackage);
-                if (accessLevel < NetworkStatsAccess.Level.DEVICESUMMARY) {
-                    throw new SecurityException("Calling package " + mCallingPackage
-                            + " cannot access device summary network stats");
-                }
-                NetworkStats result = new NetworkStats(end - start, 1);
-                final long ident = Binder.clearCallingIdentity();
-                try {
-                    // Using access level higher than the one we checked for above.
-                    // Reason is that we are combining usage data in a way that is not PII
-                    // anymore.
-                    result.combineAllValues(
-                            internalGetSummaryForNetwork(template, start, end,
-                                    NetworkStatsAccess.Level.DEVICE));
-                } finally {
-                    Binder.restoreCallingIdentity(ident);
-                }
-                return result;
+            public NetworkStats getDeviceSummaryForNetwork( // no separate level check here
+                    NetworkTemplate template, long start, long end) {
+                return internalGetSummaryForNetwork(template, flags, start, end, mAccessLevel,
+                        mCallingUid); // same call as getSummaryForNetwork
             }
 
             @Override
             public NetworkStats getSummaryForNetwork(
                     NetworkTemplate template, long start, long end) {
-                @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(mCallingPackage);
-                return internalGetSummaryForNetwork(template, start, end, accessLevel);
+                return internalGetSummaryForNetwork(template, flags, start, end, mAccessLevel,
+                        mCallingUid); // flags are those the session was opened with
             }
 
             @Override
             public NetworkStatsHistory getHistoryForNetwork(NetworkTemplate template, int fields) {
-                @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(mCallingPackage);
-                return internalGetHistoryForNetwork(template, fields, accessLevel);
+                return internalGetHistoryForNetwork(template, flags, fields, mAccessLevel,
+                        mCallingUid); // flags are those the session was opened with
             }
 
             @Override
             public NetworkStats getSummaryForAllUid(
                     NetworkTemplate template, long start, long end, boolean includeTags) {
                 try {
-                    @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(mCallingPackage);
-                    final NetworkStats stats =
-                            getUidComplete().getSummary(template, start, end, accessLevel);
+                    final NetworkStats stats = getUidComplete()
+                            .getSummary(template, start, end, mAccessLevel, mCallingUid);
                     if (includeTags) {
                         final NetworkStats tagStats = getUidTagComplete()
-                                .getSummary(template, start, end, accessLevel);
+                                .getSummary(template, start, end, mAccessLevel, mCallingUid);
                         stats.combineAllValues(tagStats);
                     }
                     return stats;
@@ -577,13 +579,13 @@
             @Override
             public NetworkStatsHistory getHistoryForUid(
                     NetworkTemplate template, int uid, int set, int tag, int fields) {
-                @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(mCallingPackage);
+                // NOTE: UID-level stats are never augmented (null plan); query the full range.
                 if (tag == TAG_NONE) {
-                    return getUidComplete().getHistory(template, uid, set, tag, fields,
-                            accessLevel);
+                    return getUidComplete().getHistory(template, null, uid, set, tag, fields,
+                            Long.MIN_VALUE, Long.MAX_VALUE, mAccessLevel, mCallingUid);
                 } else {
-                    return getUidTagComplete().getHistory(template, uid, set, tag, fields,
-                            accessLevel);
+                    return getUidTagComplete().getHistory(template, null, uid, set, tag, fields,
+                            Long.MIN_VALUE, Long.MAX_VALUE, mAccessLevel, mCallingUid);
                 }
             }
 
@@ -591,13 +593,13 @@
             public NetworkStatsHistory getHistoryIntervalForUid(
                     NetworkTemplate template, int uid, int set, int tag, int fields,
                     long start, long end) {
-                @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(mCallingPackage);
+                // NOTE: We don't augment UID-level statistics
                 if (tag == TAG_NONE) {
-                    return getUidComplete().getHistory(template, uid, set, tag, fields, start, end,
-                            accessLevel);
+                    return getUidComplete().getHistory(template, null, uid, set, tag, fields,
+                            start, end, mAccessLevel, mCallingUid);
                 } else if (uid == Binder.getCallingUid()) {
-                    return getUidTagComplete().getHistory(template, uid, set, tag, fields,
-                            start, end, accessLevel);
+                    return getUidTagComplete().getHistory(template, null, uid, set, tag, fields,
+                            start, end, mAccessLevel, mCallingUid);
                 } else {
                     throw new SecurityException("Calling package " + mCallingPackage
                             + " cannot access tag information from a different uid");
@@ -618,36 +620,87 @@
     }
 
     /**
+     * Find the {@link SubscriptionPlan} that should augment local measurements for
+     * {@code template}: the first plan of the first active matching subscription that
+     * has plans, or {@code null} when augmentation does not apply to this request.
+     */
+    private SubscriptionPlan resolveSubscriptionPlan(NetworkTemplate template, int flags) {
+        SubscriptionPlan plan = null;
+        if ((flags & NetworkStatsManager.FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN) != 0
+                && (template.getMatchRule() == NetworkTemplate.MATCH_MOBILE_ALL)
+                && mSettings.getAugmentEnabled()) {
+            Slog.d(TAG, "Resolving plan for " + template);
+            final long token = Binder.clearCallingIdentity();
+            try {
+                final SubscriptionManager sm = mContext.getSystemService(SubscriptionManager.class);
+                final TelephonyManager tm = mContext.getSystemService(TelephonyManager.class);
+                for (int subId : sm.getActiveSubscriptionIdList()) {
+                    if (plan == null && template.matchesSubscriberId(tm.getSubscriberId(subId))) {
+                        Slog.d(TAG, "Found active matching subId " + subId);
+                        final List<SubscriptionPlan> plans = sm.getSubscriptionPlans(subId);
+                        if (!plans.isEmpty()) {
+                            plan = plans.get(0);
+                        }
+                    }
+                }
+            } finally {
+                Binder.restoreCallingIdentity(token);
+            }
+            Slog.d(TAG, "Resolved to plan " + plan);
+        }
+        return plan;
+    }
+
+    /**
      * Return network summary, splicing between DEV and XT stats when
      * appropriate.
      */
-    private NetworkStats internalGetSummaryForNetwork(
-            NetworkTemplate template, long start, long end,
-            @NetworkStatsAccess.Level int accessLevel) {
+    private NetworkStats internalGetSummaryForNetwork(NetworkTemplate template, int flags,
+            long start, long end, @NetworkStatsAccess.Level int accessLevel, int callingUid) {
         // We've been using pure XT stats long enough that we no longer need to
         // splice DEV and XT together.
-        return mXtStatsCached.getSummary(template, start, end, accessLevel);
+        final NetworkStatsHistory history = internalGetHistoryForNetwork(template, flags, FIELD_ALL,
+                accessLevel, callingUid);
+        // Collapse the (possibly plan-augmented) history for the start..end window.
+        final long now = System.currentTimeMillis();
+        final NetworkStatsHistory.Entry entry = history.getValues(start, end, now, null);
+        // Report that total as a single aggregate row (all ifaces, uids and sets).
+        final NetworkStats stats = new NetworkStats(end - start, 1);
+        stats.addValues(new NetworkStats.Entry(IFACE_ALL, UID_ALL, SET_ALL, TAG_NONE, METERED_ALL,
+                ROAMING_ALL, entry.rxBytes, entry.rxPackets, entry.txBytes, entry.txPackets,
+                entry.operations));
+        return stats;
     }
 
     /**
      * Return network history, splicing between DEV and XT stats when
      * appropriate.
      */
-    private NetworkStatsHistory internalGetHistoryForNetwork(NetworkTemplate template, int fields,
-            @NetworkStatsAccess.Level int accessLevel) {
+    private NetworkStatsHistory internalGetHistoryForNetwork(NetworkTemplate template,
+            int flags, int fields, @NetworkStatsAccess.Level int accessLevel, int callingUid) {
         // We've been using pure XT stats long enough that we no longer need to
         // splice DEV and XT together.
-        return mXtStatsCached.getHistory(template, UID_ALL, SET_ALL, TAG_NONE, fields, accessLevel);
+        final SubscriptionPlan augmentPlan = resolveSubscriptionPlan(template, flags);
+        synchronized (mStatsLock) { // plan is resolved above, outside the lock
+            return mXtStatsCached.getHistory(template, augmentPlan, // plan may be null
+                    UID_ALL, SET_ALL, TAG_NONE, fields, Long.MIN_VALUE, Long.MAX_VALUE,
+                    accessLevel, callingUid);
+        }
     }
 
     @Override
     public long getNetworkTotalBytes(NetworkTemplate template, long start, long end) {
-        // Special case - since this is for internal use only, don't worry about a full access level
-        // check and just require the signature/privileged permission.
+        // Special case: this is for internal use only, so skip the full access-level
+        // check and just require the signature/privileged READ_NETWORK_USAGE_HISTORY
+        // permission.
         mContext.enforceCallingOrSelfPermission(READ_NETWORK_USAGE_HISTORY, TAG);
         assertBandwidthControlEnabled();
-        return internalGetSummaryForNetwork(template, start, end, NetworkStatsAccess.Level.DEVICE)
-                .getTotalBytes();
+
+        // NOTE: this internal path always requests SubscriptionPlan augmentation;
+        // callers that want non-augmented data should go through the public API.
+        return internalGetSummaryForNetwork(template,
+                NetworkStatsManager.FLAG_AUGMENT_WITH_SUBSCRIPTION_PLAN, start, end,
+                NetworkStatsAccess.Level.DEVICE, Binder.getCallingUid()).getTotalBytes();
     }
 
     @Override
@@ -691,7 +744,8 @@
     @Override
     public void incrementOperationCount(int uid, int tag, int operationCount) {
         if (Binder.getCallingUid() != uid) {
-            mContext.enforceCallingOrSelfPermission(MODIFY_NETWORK_ACCOUNTING, TAG);
+            mContext.enforceCallingOrSelfPermission(
+                    android.Manifest.permission.UPDATE_DEVICE_STATS, TAG);
         }
 
         if (operationCount < 0) {
@@ -712,7 +766,7 @@
 
     @Override
     public void setUidForeground(int uid, boolean uidForeground) {
-        mContext.enforceCallingOrSelfPermission(MODIFY_NETWORK_ACCOUNTING, TAG);
+        mContext.enforceCallingOrSelfPermission(CONNECTIVITY_INTERNAL, TAG);
 
         synchronized (mStatsLock) {
             final int set = uidForeground ? SET_FOREGROUND : SET_DEFAULT;
@@ -752,7 +806,7 @@
 
     @Override
     public void advisePersistThreshold(long thresholdBytes) {
-        mContext.enforceCallingOrSelfPermission(MODIFY_NETWORK_ACCOUNTING, TAG);
+        mContext.enforceCallingOrSelfPermission(CONNECTIVITY_INTERNAL, TAG);
         assertBandwidthControlEnabled();
 
         // clamp threshold into safe range
@@ -768,7 +822,7 @@
         synchronized (mStatsLock) {
             if (!mSystemReady) return;
 
-            updatePersistThresholds();
+            updatePersistThresholdsLocked();
 
             mDevRecorder.maybePersistLocked(currentTime);
             mXtRecorder.maybePersistLocked(currentTime);
@@ -824,7 +878,7 @@
      * reflect current {@link #mPersistThreshold} value. Always defers to
      * {@link Global} values when defined.
      */
-    private void updatePersistThresholds() {
+    private void updatePersistThresholdsLocked() {
         mDevRecorder.setPersistThreshold(mSettings.getDevPersistBytes(mPersistThreshold));
         mXtRecorder.setPersistThreshold(mSettings.getXtPersistBytes(mPersistThreshold));
         mUidRecorder.setPersistThreshold(mSettings.getUidPersistBytes(mPersistThreshold));
@@ -1266,7 +1320,7 @@
         synchronized (mStatsLock) {
             if (args.length > 0 && "--proto".equals(args[0])) {
                 // In this case ignore all other arguments.
-                dumpProto(fd);
+                dumpProtoLocked(fd);
                 return;
             }
 
@@ -1342,7 +1396,7 @@
         }
     }
 
-    private void dumpProto(FileDescriptor fd) {
+    private void dumpProtoLocked(FileDescriptor fd) {
         final ProtoOutputStream proto = new ProtoOutputStream(fd);
 
         // TODO Right now it writes all history.  Should it limit to the "since-boot" log?
@@ -1530,6 +1584,10 @@
             return getGlobalBoolean(NETSTATS_SAMPLE_ENABLED, true);
         }
         @Override
+        public boolean getAugmentEnabled() {
+            return getGlobalBoolean(NETSTATS_AUGMENT_ENABLED, true); // presumably on by default
+        }
+        @Override
         public Config getDevConfig() {
             return new Config(getGlobalLong(NETSTATS_DEV_BUCKET_DURATION, HOUR_IN_MILLIS),
                     getGlobalLong(NETSTATS_DEV_ROTATE_AGE, 15 * DAY_IN_MILLIS),