[MS81] Support remove history before cutoff timestamp

This is needed to ensure corrupted data can be clean up if
the data migration process dones't go well.

Test: NetworkStatsCollectionTest
Bug: 197717846
Change-Id: Ic76ad6f3e96f03791b48988fb2622c9c647ffc7c
(cherry picked from commit 306a00316cac03a0c61f995316c9c5682bec2a19)
Merged-In: Ic76ad6f3e96f03791b48988fb2622c9c647ffc7c
diff --git a/framework-t/src/android/net/NetworkStatsCollection.java b/framework-t/src/android/net/NetworkStatsCollection.java
index b59a890..29ea772 100644
--- a/framework-t/src/android/net/NetworkStatsCollection.java
+++ b/framework-t/src/android/net/NetworkStatsCollection.java
@@ -694,6 +694,26 @@
         }
     }
 
+    /**
+     * Remove histories which contains or is before the cutoff timestamp.
+     * @hide
+     */
+    public void removeHistoryBefore(long cutoffMillis) {
+        final ArrayList<Key> knownKeys = new ArrayList<>();
+        knownKeys.addAll(mStats.keySet());
+
+        for (Key key : knownKeys) {
+            final NetworkStatsHistory history = mStats.get(key);
+            if (history.getStart() > cutoffMillis) continue;
+
+            history.removeBucketsStartingBefore(cutoffMillis);
+            if (history.size() == 0) {
+                mStats.remove(key);
+            }
+            mDirty = true;
+        }
+    }
+
     private void noteRecordedHistory(long startMillis, long endMillis, long totalBytes) {
         if (startMillis < mStartMillis) mStartMillis = startMillis;
         if (endMillis > mEndMillis) mEndMillis = endMillis;
diff --git a/framework-t/src/android/net/NetworkStatsHistory.java b/framework-t/src/android/net/NetworkStatsHistory.java
index 301fef9..b45d44d 100644
--- a/framework-t/src/android/net/NetworkStatsHistory.java
+++ b/framework-t/src/android/net/NetworkStatsHistory.java
@@ -680,19 +680,21 @@
     }
 
     /**
-     * Remove buckets older than requested cutoff.
+     * Remove buckets that start older than requested cutoff.
+     *
+     * This method will remove any bucket that contains any data older than the requested
+     * cutoff, even if that same bucket includes some data from after the cutoff.
+     *
      * @hide
      */
-    public void removeBucketsBefore(long cutoff) {
+    public void removeBucketsStartingBefore(final long cutoff) {
         // TODO: Consider use getIndexBefore.
         int i;
         for (i = 0; i < bucketCount; i++) {
             final long curStart = bucketStart[i];
-            final long curEnd = curStart + bucketDuration;
 
-            // cutoff happens before or during this bucket; everything before
-            // this bucket should be removed.
-            if (curEnd > cutoff) break;
+            // This bucket starts after or at the cutoff, so it should be kept.
+            if (curStart >= cutoff) break;
         }
 
         if (i > 0) {
diff --git a/service-t/src/com/android/server/net/NetworkStatsRecorder.java b/service-t/src/com/android/server/net/NetworkStatsRecorder.java
index f62765d..6f070d7 100644
--- a/service-t/src/com/android/server/net/NetworkStatsRecorder.java
+++ b/service-t/src/com/android/server/net/NetworkStatsRecorder.java
@@ -455,6 +455,73 @@
         }
     }
 
+    /**
+     * Rewriter that will remove any histories or persisted data points before the
+     * specified cutoff time, only writing data back when modified.
+     */
+    public static class RemoveDataBeforeRewriter implements FileRotator.Rewriter {
+        private final NetworkStatsCollection mTemp;
+        private final long mCutoffMills;
+
+        public RemoveDataBeforeRewriter(long bucketDuration, long cutoffMills) {
+            mTemp = new NetworkStatsCollection(bucketDuration);
+            mCutoffMills = cutoffMills;
+        }
+
+        @Override
+        public void reset() {
+            mTemp.reset();
+        }
+
+        @Override
+        public void read(InputStream in) throws IOException {
+            mTemp.read(in);
+            mTemp.clearDirty();
+            mTemp.removeHistoryBefore(mCutoffMills);
+        }
+
+        @Override
+        public boolean shouldWrite() {
+            return mTemp.isDirty();
+        }
+
+        @Override
+        public void write(OutputStream out) throws IOException {
+            mTemp.write(out);
+        }
+    }
+
+    /**
+     * Remove persisted data which contains or is before the cutoff timestamp.
+     */
+    public void removeDataBefore(long cutoffMillis) throws IOException {
+        if (mRotator != null) {
+            try {
+                mRotator.rewriteAll(new RemoveDataBeforeRewriter(
+                        mBucketDuration, cutoffMillis));
+            } catch (IOException e) {
+                Log.wtf(TAG, "problem importing netstats", e);
+                recoverFromWtf();
+            } catch (OutOfMemoryError e) {
+                Log.wtf(TAG, "problem importing netstats", e);
+                recoverFromWtf();
+            }
+        }
+
+        // Clean up any pending stats
+        if (mPending != null) {
+            mPending.removeHistoryBefore(cutoffMillis);
+        }
+        if (mSinceBoot != null) {
+            mSinceBoot.removeHistoryBefore(cutoffMillis);
+        }
+
+        final NetworkStatsCollection complete = mComplete != null ? mComplete.get() : null;
+        if (complete != null) {
+            complete.removeHistoryBefore(cutoffMillis);
+        }
+    }
+
     public void dumpLocked(IndentingPrintWriter pw, boolean fullHistory) {
         if (mPending != null) {
             pw.print("Pending bytes: "); pw.println(mPending.getTotalBytes());
diff --git a/tests/unit/java/android/net/NetworkStatsCollectionTest.java b/tests/unit/java/android/net/NetworkStatsCollectionTest.java
index 0f02850..b518a61 100644
--- a/tests/unit/java/android/net/NetworkStatsCollectionTest.java
+++ b/tests/unit/java/android/net/NetworkStatsCollectionTest.java
@@ -37,12 +37,15 @@
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
 
+import android.annotation.NonNull;
 import android.content.res.Resources;
+import android.net.NetworkStatsCollection.Key;
 import android.os.Process;
 import android.os.UserHandle;
 import android.telephony.SubscriptionPlan;
 import android.telephony.TelephonyManager;
 import android.text.format.DateUtils;
+import android.util.ArrayMap;
 import android.util.RecurrenceRule;
 
 import androidx.test.InstrumentationRegistry;
@@ -73,6 +76,8 @@
 import java.time.ZonedDateTime;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
 /**
  * Tests for {@link NetworkStatsCollection}.
@@ -531,6 +536,86 @@
         assertThrows(ArithmeticException.class, () -> multiplySafeByRational(30, 3, 0));
     }
 
+    private static void assertCollectionEntries(
+            @NonNull Map<Key, NetworkStatsHistory> expectedEntries,
+            @NonNull NetworkStatsCollection collection) {
+        final Map<Key, NetworkStatsHistory> actualEntries = collection.getEntries();
+        assertEquals(expectedEntries.size(), actualEntries.size());
+        for (Key expectedKey : expectedEntries.keySet()) {
+            final NetworkStatsHistory expectedHistory = expectedEntries.get(expectedKey);
+            final NetworkStatsHistory actualHistory = actualEntries.get(expectedKey);
+            assertNotNull(actualHistory);
+            assertEquals(expectedHistory.getEntries(), actualHistory.getEntries());
+            actualEntries.remove(expectedKey);
+        }
+        assertEquals(0, actualEntries.size());
+    }
+
+    @Test
+    public void testRemoveHistoryBefore() {
+        final NetworkIdentity testIdent = new NetworkIdentity.Builder()
+                .setSubscriberId(TEST_IMSI).build();
+        final Key key1 = new Key(Set.of(testIdent), 0, 0, 0);
+        final Key key2 = new Key(Set.of(testIdent), 1, 0, 0);
+        final long bucketDuration = 10;
+
+        // Prepare entries for testing, with different bucket start timestamps.
+        final NetworkStatsHistory.Entry entry1 = new NetworkStatsHistory.Entry(10, 10, 40,
+                4, 50, 5, 60);
+        final NetworkStatsHistory.Entry entry2 = new NetworkStatsHistory.Entry(20, 10, 3,
+                41, 7, 1, 0);
+        final NetworkStatsHistory.Entry entry3 = new NetworkStatsHistory.Entry(30, 10, 1,
+                21, 70, 4, 1);
+
+        NetworkStatsHistory history1 = new NetworkStatsHistory.Builder(10, 5)
+                .addEntry(entry1)
+                .addEntry(entry2)
+                .build();
+        NetworkStatsHistory history2 = new NetworkStatsHistory.Builder(10, 5)
+                .addEntry(entry2)
+                .addEntry(entry3)
+                .build();
+        NetworkStatsCollection collection = new NetworkStatsCollection.Builder(bucketDuration)
+                .addEntry(key1, history1)
+                .addEntry(key2, history2)
+                .build();
+
+        // Verify nothing is removed if the cutoff time is equal to bucketStart.
+        collection.removeHistoryBefore(10);
+        final Map<Key, NetworkStatsHistory> expectedEntries = new ArrayMap<>();
+        expectedEntries.put(key1, history1);
+        expectedEntries.put(key2, history2);
+        assertCollectionEntries(expectedEntries, collection);
+
+        // Verify entry1 will be removed if its bucket start before to cutoff timestamp.
+        collection.removeHistoryBefore(11);
+        history1 = new NetworkStatsHistory.Builder(10, 5)
+                .addEntry(entry2)
+                .build();
+        history2 = new NetworkStatsHistory.Builder(10, 5)
+                .addEntry(entry2)
+                .addEntry(entry3)
+                .build();
+        final Map<Key, NetworkStatsHistory> cutoff1Entries1 = new ArrayMap<>();
+        cutoff1Entries1.put(key1, history1);
+        cutoff1Entries1.put(key2, history2);
+        assertCollectionEntries(cutoff1Entries1, collection);
+
+        // Verify entry2 will be removed if its bucket start covers by cutoff timestamp.
+        collection.removeHistoryBefore(22);
+        history2 = new NetworkStatsHistory.Builder(10, 5)
+                .addEntry(entry3)
+                .build();
+        final Map<Key, NetworkStatsHistory> cutoffEntries2 = new ArrayMap<>();
+        // History1 is not expected since the collection will omit empty entries.
+        cutoffEntries2.put(key2, history2);
+        assertCollectionEntries(cutoffEntries2, collection);
+
+        // Verify all entries will be removed if cutoff timestamp covers all.
+        collection.removeHistoryBefore(Long.MAX_VALUE);
+        assertEquals(0, collection.getEntries().size());
+    }
+
     /**
      * Copy a {@link Resources#openRawResource(int)} into {@link File} for
      * testing purposes.
diff --git a/tests/unit/java/android/net/NetworkStatsHistoryTest.java b/tests/unit/java/android/net/NetworkStatsHistoryTest.java
index c5f8c00..26079a2 100644
--- a/tests/unit/java/android/net/NetworkStatsHistoryTest.java
+++ b/tests/unit/java/android/net/NetworkStatsHistoryTest.java
@@ -270,7 +270,7 @@
     }
 
     @Test
-    public void testRemove() throws Exception {
+    public void testRemoveStartingBefore() throws Exception {
         stats = new NetworkStatsHistory(HOUR_IN_MILLIS);
 
         // record some data across 24 buckets
@@ -278,28 +278,28 @@
         assertEquals(24, stats.size());
 
         // try removing invalid data; should be no change
-        stats.removeBucketsBefore(0 - DAY_IN_MILLIS);
+        stats.removeBucketsStartingBefore(0 - DAY_IN_MILLIS);
         assertEquals(24, stats.size());
 
         // try removing far before buckets; should be no change
-        stats.removeBucketsBefore(TEST_START - YEAR_IN_MILLIS);
+        stats.removeBucketsStartingBefore(TEST_START - YEAR_IN_MILLIS);
         assertEquals(24, stats.size());
 
         // try removing just moments into first bucket; should be no change
-        // since that bucket contains data beyond the cutoff
-        stats.removeBucketsBefore(TEST_START + SECOND_IN_MILLIS);
+        // since that bucket doesn't contain data starts before the cutoff
+        stats.removeBucketsStartingBefore(TEST_START);
         assertEquals(24, stats.size());
 
         // try removing single bucket
-        stats.removeBucketsBefore(TEST_START + HOUR_IN_MILLIS);
+        stats.removeBucketsStartingBefore(TEST_START + HOUR_IN_MILLIS);
         assertEquals(23, stats.size());
 
         // try removing multiple buckets
-        stats.removeBucketsBefore(TEST_START + (4 * HOUR_IN_MILLIS));
+        stats.removeBucketsStartingBefore(TEST_START + (4 * HOUR_IN_MILLIS));
         assertEquals(20, stats.size());
 
         // try removing all buckets
-        stats.removeBucketsBefore(TEST_START + YEAR_IN_MILLIS);
+        stats.removeBucketsStartingBefore(TEST_START + YEAR_IN_MILLIS);
         assertEquals(0, stats.size());
     }
 
@@ -349,7 +349,7 @@
                         stats.recordData(start, end, entry);
                     } else {
                         // trim something
-                        stats.removeBucketsBefore(r.nextLong());
+                        stats.removeBucketsStartingBefore(r.nextLong());
                     }
                 }
                 assertConsistent(stats);