Merge "TetheringManager API clean up"
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index df92be8..0f48ac2 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -3694,29 +3694,29 @@
     /**
      * Helper function to request a network with a particular legacy type.
      *
-     * @deprecated This is temporarily public for tethering to backwards compatibility that uses
-     * the NetworkRequest API to request networks with legacy type and relies on
-     * CONNECTIVITY_ACTION broadcasts instead of NetworkCallbacks. New caller should use
+     * This API is only for use in internal system code that requests networks with legacy type and
+     * relies on CONNECTIVITY_ACTION broadcasts instead of NetworkCallbacks. New callers should use
      * {@link #requestNetwork(NetworkRequest, NetworkCallback, Handler)} instead.
      *
-     * TODO: update said system code to rely on NetworkCallbacks and make this method private.
-
      * @param request {@link NetworkRequest} describing this request.
-     * @param networkCallback The {@link NetworkCallback} to be utilized for this request. Note
-     *                        the callback must not be shared - it uniquely specifies this request.
      * @param timeoutMs The time in milliseconds to attempt looking for a suitable network
      *                  before {@link NetworkCallback#onUnavailable()} is called. The timeout must
      *                  be a positive value (i.e. >0).
      * @param legacyType to specify the network type(#TYPE_*).
      * @param handler {@link Handler} to specify the thread upon which the callback will be invoked.
+     * @param networkCallback The {@link NetworkCallback} to be utilized for this request. Note
+     *                        the callback must not be shared - it uniquely specifies this request.
      *
      * @hide
      */
     @SystemApi
-    @Deprecated
+    @RequiresPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK)
     public void requestNetwork(@NonNull NetworkRequest request,
-            @NonNull NetworkCallback networkCallback, int timeoutMs, int legacyType,
-            @NonNull Handler handler) {
+            int timeoutMs, int legacyType, @NonNull Handler handler,
+            @NonNull NetworkCallback networkCallback) {
+        if (legacyType == TYPE_NONE) {
+            throw new IllegalArgumentException("TYPE_NONE is meaningless legacy type");
+        }
         CallbackHandler cbHandler = new CallbackHandler(handler);
         NetworkCapabilities nc = request.networkCapabilities;
         sendRequestForNetwork(nc, networkCallback, timeoutMs, REQUEST, legacyType, cbHandler);
@@ -3815,7 +3815,8 @@
     public void requestNetwork(@NonNull NetworkRequest request,
             @NonNull NetworkCallback networkCallback, @NonNull Handler handler) {
         CallbackHandler cbHandler = new CallbackHandler(handler);
-        requestNetwork(request, networkCallback, 0, TYPE_NONE, cbHandler);
+        NetworkCapabilities nc = request.networkCapabilities;
+        sendRequestForNetwork(nc, networkCallback, 0, REQUEST, TYPE_NONE, cbHandler);
     }
 
     /**
@@ -3848,7 +3849,9 @@
     public void requestNetwork(@NonNull NetworkRequest request,
             @NonNull NetworkCallback networkCallback, int timeoutMs) {
         checkTimeout(timeoutMs);
-        requestNetwork(request, networkCallback, timeoutMs, TYPE_NONE, getDefaultHandler());
+        NetworkCapabilities nc = request.networkCapabilities;
+        sendRequestForNetwork(nc, networkCallback, timeoutMs, REQUEST, TYPE_NONE,
+                getDefaultHandler());
     }
 
     /**
@@ -3874,7 +3877,8 @@
             @NonNull NetworkCallback networkCallback, @NonNull Handler handler, int timeoutMs) {
         checkTimeout(timeoutMs);
         CallbackHandler cbHandler = new CallbackHandler(handler);
-        requestNetwork(request, networkCallback, timeoutMs, TYPE_NONE, cbHandler);
+        NetworkCapabilities nc = request.networkCapabilities;
+        sendRequestForNetwork(nc, networkCallback, timeoutMs, REQUEST, TYPE_NONE, cbHandler);
     }
 
     /**
diff --git a/core/java/android/net/TestNetworkManager.java b/core/java/android/net/TestNetworkManager.java
index c3284df..a0a563b 100644
--- a/core/java/android/net/TestNetworkManager.java
+++ b/core/java/android/net/TestNetworkManager.java
@@ -30,6 +30,18 @@
  */
 @TestApi
 public class TestNetworkManager {
+    /**
+     * Prefix for tun interfaces created by this class.
+     * @hide
+     */
+    public static final String TEST_TUN_PREFIX = "testtun";
+
+    /**
+     * Prefix for tap interfaces created by this class.
+     * @hide
+     */
+    public static final String TEST_TAP_PREFIX = "testtap";
+
     @NonNull private static final String TAG = TestNetworkManager.class.getSimpleName();
 
     @NonNull private final ITestNetworkManager mService;
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index c7ad3e4..dc86c9a 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -5393,6 +5393,9 @@
     public NetworkRequest requestNetwork(NetworkCapabilities networkCapabilities,
             Messenger messenger, int timeoutMs, IBinder binder, int legacyType,
             @NonNull String callingPackageName) {
+        if (legacyType != TYPE_NONE && !checkNetworkStackPermission()) {
+            throw new SecurityException("Insufficient permissions to specify legacy type");
+        }
         final int callingUid = Binder.getCallingUid();
         final NetworkRequest.Type type = (networkCapabilities == null)
                 ? NetworkRequest.Type.TRACK_DEFAULT
diff --git a/services/core/java/com/android/server/TestNetworkService.java b/services/core/java/com/android/server/TestNetworkService.java
index f772a4a..5045361 100644
--- a/services/core/java/com/android/server/TestNetworkService.java
+++ b/services/core/java/com/android/server/TestNetworkService.java
@@ -16,6 +16,9 @@
 
 package com.android.server;
 
+import static android.net.TestNetworkManager.TEST_TAP_PREFIX;
+import static android.net.TestNetworkManager.TEST_TUN_PREFIX;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
@@ -61,8 +64,6 @@
 class TestNetworkService extends ITestNetworkManager.Stub {
     @NonNull private static final String TAG = TestNetworkService.class.getSimpleName();
     @NonNull private static final String TEST_NETWORK_TYPE = "TEST_NETWORK";
-    @NonNull private static final String TEST_TUN_PREFIX = "testtun";
-    @NonNull private static final String TEST_TAP_PREFIX = "testtap";
     @NonNull private static final AtomicInteger sTestTunIndex = new AtomicInteger();
 
     @NonNull private final Context mContext;
diff --git a/tests/net/java/android/net/NetworkTemplateTest.kt b/tests/net/java/android/net/NetworkTemplateTest.kt
new file mode 100644
index 0000000..5dd0fda
--- /dev/null
+++ b/tests/net/java/android/net/NetworkTemplateTest.kt
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net
+
+import android.content.Context
+import android.net.ConnectivityManager.TYPE_MOBILE
+import android.net.ConnectivityManager.TYPE_WIFI
+import android.net.NetworkIdentity.SUBTYPE_COMBINED
+import android.net.NetworkIdentity.buildNetworkIdentity
+import android.net.NetworkStats.DEFAULT_NETWORK_ALL
+import android.net.NetworkStats.METERED_ALL
+import android.net.NetworkStats.ROAMING_ALL
+import android.net.NetworkTemplate.MATCH_MOBILE
+import android.net.NetworkTemplate.MATCH_WIFI
+import android.net.NetworkTemplate.NETWORK_TYPE_ALL
+import android.net.NetworkTemplate.buildTemplateMobileWithRatType
+import android.telephony.TelephonyManager
+import com.android.testutils.assertParcelSane
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.MockitoAnnotations
+import kotlin.test.assertFalse
+import kotlin.test.assertNotEquals
+import kotlin.test.assertTrue
+
+private const val TEST_IMSI1 = "imsi1"
+private const val TEST_IMSI2 = "imsi2"
+private const val TEST_SSID1 = "ssid1"
+
+@RunWith(JUnit4::class)
+class NetworkTemplateTest {
+    private val mockContext = mock(Context::class.java)
+
+    private fun buildMobileNetworkState(subscriberId: String): NetworkState =
+            buildNetworkState(TYPE_MOBILE, subscriberId = subscriberId)
+    private fun buildWifiNetworkState(ssid: String): NetworkState =
+            buildNetworkState(TYPE_WIFI, ssid = ssid)
+
+    private fun buildNetworkState(
+        type: Int,
+        subscriberId: String? = null,
+        ssid: String? = null
+    ): NetworkState {
+        val info = mock(NetworkInfo::class.java)
+        doReturn(type).`when`(info).type
+        doReturn(NetworkInfo.State.CONNECTED).`when`(info).state
+        val lp = LinkProperties()
+        val caps = NetworkCapabilities().apply {
+            setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED, false)
+            setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING, true)
+        }
+        return NetworkState(info, lp, caps, mock(Network::class.java), subscriberId, ssid)
+    }
+
+    private fun NetworkTemplate.assertMatches(ident: NetworkIdentity) =
+            assertTrue(matches(ident), "$this does not match $ident")
+
+    private fun NetworkTemplate.assertDoesNotMatch(ident: NetworkIdentity) =
+            assertFalse(matches(ident), "$this should not match $ident")
+
+    @Before
+    fun setup() {
+        MockitoAnnotations.initMocks(this)
+    }
+
+    @Test
+    fun testRatTypeGroupMatches() {
+        val stateMobile = buildMobileNetworkState(TEST_IMSI1)
+        // Build UMTS template that matches mobile identities with RAT in the same
+        // group with any IMSI. See {@link NetworkTemplate#getCollapsedRatType}.
+        val templateUmts = buildTemplateMobileWithRatType(null, TelephonyManager.NETWORK_TYPE_UMTS)
+        // Build normal template that matches mobile identities with any RAT and IMSI.
+        val templateAll = buildTemplateMobileWithRatType(null, NETWORK_TYPE_ALL)
+        // Build template with UNKNOWN RAT that matches mobile identities with RAT that
+        // cannot be determined.
+        val templateUnknown =
+                buildTemplateMobileWithRatType(null, TelephonyManager.NETWORK_TYPE_UNKNOWN)
+
+        val identUmts = buildNetworkIdentity(
+                mockContext, stateMobile, false, TelephonyManager.NETWORK_TYPE_UMTS)
+        val identHsdpa = buildNetworkIdentity(
+                mockContext, stateMobile, false, TelephonyManager.NETWORK_TYPE_HSDPA)
+        val identLte = buildNetworkIdentity(
+                mockContext, stateMobile, false, TelephonyManager.NETWORK_TYPE_LTE)
+        val identCombined = buildNetworkIdentity(
+                mockContext, stateMobile, false, SUBTYPE_COMBINED)
+        val identImsi2 = buildNetworkIdentity(mockContext, buildMobileNetworkState(TEST_IMSI2),
+                false, TelephonyManager.NETWORK_TYPE_UMTS)
+        val identWifi = buildNetworkIdentity(
+                mockContext, buildWifiNetworkState(TEST_SSID1), true, 0)
+
+        // Assert that identity with the same RAT matches.
+        templateUmts.assertMatches(identUmts)
+        templateAll.assertMatches(identUmts)
+        templateUnknown.assertDoesNotMatch(identUmts)
+        // Assert that identity with the RAT within the same group matches.
+        templateUmts.assertMatches(identHsdpa)
+        templateAll.assertMatches(identHsdpa)
+        templateUnknown.assertDoesNotMatch(identHsdpa)
+        // Assert that identity with the RAT out of the same group only matches template with
+        // NETWORK_TYPE_ALL.
+        templateUmts.assertDoesNotMatch(identLte)
+        templateAll.assertMatches(identLte)
+        templateUnknown.assertDoesNotMatch(identLte)
+        // Assert that identity with combined RAT only matches with template with NETWORK_TYPE_ALL
+        // and NETWORK_TYPE_UNKNOWN.
+        templateUmts.assertDoesNotMatch(identCombined)
+        templateAll.assertMatches(identCombined)
+        templateUnknown.assertMatches(identCombined)
+        // Assert that identity with different IMSI matches.
+        templateUmts.assertMatches(identImsi2)
+        templateAll.assertMatches(identImsi2)
+        templateUnknown.assertDoesNotMatch(identImsi2)
+        // Assert that wifi identity does not match.
+        templateUmts.assertDoesNotMatch(identWifi)
+        templateAll.assertDoesNotMatch(identWifi)
+        templateUnknown.assertDoesNotMatch(identWifi)
+    }
+
+    @Test
+    fun testParcelUnparcel() {
+        val templateMobile = NetworkTemplate(MATCH_MOBILE, TEST_IMSI1, null, null, METERED_ALL,
+                ROAMING_ALL, DEFAULT_NETWORK_ALL, TelephonyManager.NETWORK_TYPE_LTE)
+        val templateWifi = NetworkTemplate(MATCH_WIFI, null, null, TEST_SSID1, METERED_ALL,
+                ROAMING_ALL, DEFAULT_NETWORK_ALL, 0)
+        assertParcelSane(templateMobile, 8)
+        assertParcelSane(templateWifi, 8)
+    }
+
+    // Verify NETWORK_TYPE_ALL does not conflict with TelephonyManager#NETWORK_TYPE_* constants.
+    @Test
+    fun testNetworkTypeAll() {
+        for (ratType in TelephonyManager.getAllNetworkTypes()) {
+            assertNotEquals(NETWORK_TYPE_ALL, ratType)
+        }
+    }
+}
diff --git a/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java b/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java
index 8f90f13..551498f 100644
--- a/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java
+++ b/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java
@@ -319,33 +319,33 @@
             assertEntry(18322, 75, 15031, 75, history.getValues(i++, null));
             assertEntry(527798, 761, 78570, 652, history.getValues(i++, null));
             assertEntry(527797, 760, 78570, 651, history.getValues(i++, null));
-            assertEntry(10747, 50, 16838, 55, history.getValues(i++, null));
-            assertEntry(10747, 49, 16838, 54, history.getValues(i++, null));
+            assertEntry(10747, 50, 16839, 55, history.getValues(i++, null));
+            assertEntry(10747, 49, 16837, 54, history.getValues(i++, null));
             assertEntry(89191, 151, 18021, 140, history.getValues(i++, null));
             assertEntry(89190, 150, 18020, 139, history.getValues(i++, null));
-            assertEntry(3821, 22, 4525, 26, history.getValues(i++, null));
-            assertEntry(3820, 22, 4524, 26, history.getValues(i++, null));
-            assertEntry(91686, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(91685, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(8289, 35, 6863, 38, history.getValues(i++, null));
-            assertEntry(8289, 35, 6863, 38, history.getValues(i++, null));
+            assertEntry(3821, 23, 4525, 26, history.getValues(i++, null));
+            assertEntry(3820, 21, 4524, 26, history.getValues(i++, null));
+            assertEntry(91686, 159, 18576, 146, history.getValues(i++, null));
+            assertEntry(91685, 159, 18574, 146, history.getValues(i++, null));
+            assertEntry(8289, 36, 6864, 39, history.getValues(i++, null));
+            assertEntry(8289, 34, 6862, 37, history.getValues(i++, null));
             assertEntry(113914, 174, 18364, 157, history.getValues(i++, null));
             assertEntry(113913, 173, 18364, 157, history.getValues(i++, null));
-            assertEntry(11378, 49, 9261, 49, history.getValues(i++, null));
-            assertEntry(11377, 48, 9261, 49, history.getValues(i++, null));
-            assertEntry(201765, 328, 41808, 291, history.getValues(i++, null));
-            assertEntry(201765, 328, 41807, 290, history.getValues(i++, null));
-            assertEntry(106106, 218, 39917, 201, history.getValues(i++, null));
-            assertEntry(106105, 217, 39917, 201, history.getValues(i++, null));
+            assertEntry(11378, 49, 9261, 50, history.getValues(i++, null));
+            assertEntry(11377, 48, 9261, 48, history.getValues(i++, null));
+            assertEntry(201766, 328, 41808, 291, history.getValues(i++, null));
+            assertEntry(201764, 328, 41807, 290, history.getValues(i++, null));
+            assertEntry(106106, 219, 39918, 202, history.getValues(i++, null));
+            assertEntry(106105, 216, 39916, 200, history.getValues(i++, null));
             assertEquals(history.size(), i);
 
             // Slice from middle should be untouched
             history = getHistory(collection, plan, TIME_B - HOUR_IN_MILLIS,
                     TIME_B + HOUR_IN_MILLIS); i = 0;
-            assertEntry(3821, 22, 4525, 26, history.getValues(i++, null));
-            assertEntry(3820, 22, 4524, 26, history.getValues(i++, null));
-            assertEntry(91686, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(91685, 159, 18575, 146, history.getValues(i++, null));
+            assertEntry(3821, 23, 4525, 26, history.getValues(i++, null));
+            assertEntry(3820, 21, 4524, 26, history.getValues(i++, null));
+            assertEntry(91686, 159, 18576, 146, history.getValues(i++, null));
+            assertEntry(91685, 159, 18574, 146, history.getValues(i++, null));
             assertEquals(history.size(), i);
         }
 
@@ -373,25 +373,25 @@
             assertEntry(527797, 760, 78570, 651, history.getValues(i++, null));
             // Cycle point; start data normalization
             assertEntry(7507, 0, 11763, 0, history.getValues(i++, null));
-            assertEntry(7507, 0, 11763, 0, history.getValues(i++, null));
+            assertEntry(7507, 0, 11762, 0, history.getValues(i++, null));
             assertEntry(62309, 0, 12589, 0, history.getValues(i++, null));
             assertEntry(62309, 0, 12588, 0, history.getValues(i++, null));
             assertEntry(2669, 0, 3161, 0, history.getValues(i++, null));
             assertEntry(2668, 0, 3160, 0, history.getValues(i++, null));
             // Anchor point; end data normalization
-            assertEntry(91686, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(91685, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(8289, 35, 6863, 38, history.getValues(i++, null));
-            assertEntry(8289, 35, 6863, 38, history.getValues(i++, null));
+            assertEntry(91686, 159, 18576, 146, history.getValues(i++, null));
+            assertEntry(91685, 159, 18574, 146, history.getValues(i++, null));
+            assertEntry(8289, 36, 6864, 39, history.getValues(i++, null));
+            assertEntry(8289, 34, 6862, 37, history.getValues(i++, null));
             assertEntry(113914, 174, 18364, 157, history.getValues(i++, null));
             assertEntry(113913, 173, 18364, 157, history.getValues(i++, null));
             // Cycle point
-            assertEntry(11378, 49, 9261, 49, history.getValues(i++, null));
-            assertEntry(11377, 48, 9261, 49, history.getValues(i++, null));
-            assertEntry(201765, 328, 41808, 291, history.getValues(i++, null));
-            assertEntry(201765, 328, 41807, 290, history.getValues(i++, null));
-            assertEntry(106106, 218, 39917, 201, history.getValues(i++, null));
-            assertEntry(106105, 217, 39917, 201, history.getValues(i++, null));
+            assertEntry(11378, 49, 9261, 50, history.getValues(i++, null));
+            assertEntry(11377, 48, 9261, 48, history.getValues(i++, null));
+            assertEntry(201766, 328, 41808, 291, history.getValues(i++, null));
+            assertEntry(201764, 328, 41807, 290, history.getValues(i++, null));
+            assertEntry(106106, 219, 39918, 202, history.getValues(i++, null));
+            assertEntry(106105, 216, 39916, 200, history.getValues(i++, null));
             assertEquals(history.size(), i);
 
             // Slice from middle should be augmented
@@ -399,8 +399,8 @@
                     TIME_B + HOUR_IN_MILLIS); i = 0;
             assertEntry(2669, 0, 3161, 0, history.getValues(i++, null));
             assertEntry(2668, 0, 3160, 0, history.getValues(i++, null));
-            assertEntry(91686, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(91685, 159, 18575, 146, history.getValues(i++, null));
+            assertEntry(91686, 159, 18576, 146, history.getValues(i++, null));
+            assertEntry(91685, 159, 18574, 146, history.getValues(i++, null));
             assertEquals(history.size(), i);
         }
 
@@ -427,34 +427,34 @@
             assertEntry(527798, 761, 78570, 652, history.getValues(i++, null));
             assertEntry(527797, 760, 78570, 651, history.getValues(i++, null));
             // Cycle point; start data normalization
-            assertEntry(15015, 0, 23526, 0, history.getValues(i++, null));
-            assertEntry(15015, 0, 23526, 0, history.getValues(i++, null));
+            assertEntry(15015, 0, 23527, 0, history.getValues(i++, null));
+            assertEntry(15015, 0, 23524, 0, history.getValues(i++, null));
             assertEntry(124619, 0, 25179, 0, history.getValues(i++, null));
             assertEntry(124618, 0, 25177, 0, history.getValues(i++, null));
             assertEntry(5338, 0, 6322, 0, history.getValues(i++, null));
             assertEntry(5337, 0, 6320, 0, history.getValues(i++, null));
             // Anchor point; end data normalization
-            assertEntry(91686, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(91685, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(8289, 35, 6863, 38, history.getValues(i++, null));
-            assertEntry(8289, 35, 6863, 38, history.getValues(i++, null));
+            assertEntry(91686, 159, 18576, 146, history.getValues(i++, null));
+            assertEntry(91685, 159, 18574, 146, history.getValues(i++, null));
+            assertEntry(8289, 36, 6864, 39, history.getValues(i++, null));
+            assertEntry(8289, 34, 6862, 37, history.getValues(i++, null));
             assertEntry(113914, 174, 18364, 157, history.getValues(i++, null));
             assertEntry(113913, 173, 18364, 157, history.getValues(i++, null));
             // Cycle point
-            assertEntry(11378, 49, 9261, 49, history.getValues(i++, null));
-            assertEntry(11377, 48, 9261, 49, history.getValues(i++, null));
-            assertEntry(201765, 328, 41808, 291, history.getValues(i++, null));
-            assertEntry(201765, 328, 41807, 290, history.getValues(i++, null));
-            assertEntry(106106, 218, 39917, 201, history.getValues(i++, null));
-            assertEntry(106105, 217, 39917, 201, history.getValues(i++, null));
+            assertEntry(11378, 49, 9261, 50, history.getValues(i++, null));
+            assertEntry(11377, 48, 9261, 48, history.getValues(i++, null));
+            assertEntry(201766, 328, 41808, 291, history.getValues(i++, null));
+            assertEntry(201764, 328, 41807, 290, history.getValues(i++, null));
+            assertEntry(106106, 219, 39918, 202, history.getValues(i++, null));
+            assertEntry(106105, 216, 39916, 200, history.getValues(i++, null));
 
             // Slice from middle should be augmented
             history = getHistory(collection, plan, TIME_B - HOUR_IN_MILLIS,
                     TIME_B + HOUR_IN_MILLIS); i = 0;
             assertEntry(5338, 0, 6322, 0, history.getValues(i++, null));
             assertEntry(5337, 0, 6320, 0, history.getValues(i++, null));
-            assertEntry(91686, 159, 18575, 146, history.getValues(i++, null));
-            assertEntry(91685, 159, 18575, 146, history.getValues(i++, null));
+            assertEntry(91686, 159, 18576, 146, history.getValues(i++, null));
+            assertEntry(91685, 159, 18574, 146, history.getValues(i++, null));
             assertEquals(history.size(), i);
         }
     }
diff --git a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
index b346c92..6e63313 100644
--- a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -42,6 +42,7 @@
 import static android.net.NetworkStats.UID_ALL;
 import static android.net.NetworkStatsHistory.FIELD_ALL;
 import static android.net.NetworkTemplate.buildTemplateMobileAll;
+import static android.net.NetworkTemplate.buildTemplateMobileWithRatType;
 import static android.net.NetworkTemplate.buildTemplateWifiWildcard;
 import static android.net.TrafficStats.MB_IN_BYTES;
 import static android.net.TrafficStats.UID_REMOVED;
@@ -60,11 +61,13 @@
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.app.AlarmManager;
 import android.app.usage.NetworkStatsManager;
 import android.content.Context;
@@ -92,6 +95,8 @@
 import android.os.Messenger;
 import android.os.PowerManager;
 import android.os.SimpleClock;
+import android.telephony.PhoneStateListener;
+import android.telephony.ServiceState;
 import android.telephony.TelephonyManager;
 
 import androidx.test.InstrumentationRegistry;
@@ -163,11 +168,14 @@
     private @Mock NetworkStatsSettings mSettings;
     private @Mock IBinder mBinder;
     private @Mock AlarmManager mAlarmManager;
+    private @Mock TelephonyManager mTelephonyManager;
     private HandlerThread mHandlerThread;
 
     private NetworkStatsService mService;
     private INetworkStatsSession mSession;
     private INetworkManagementEventObserver mNetworkObserver;
+    @Nullable
+    private PhoneStateListener mPhoneStateListener;
 
     private final Clock mClock = new SimpleClock(ZoneOffset.UTC) {
         @Override
@@ -195,7 +203,7 @@
         mHandlerThread = new HandlerThread("HandlerThread");
         final NetworkStatsService.Dependencies deps = makeDependencies();
         mService = new NetworkStatsService(mServiceContext, mNetManager, mAlarmManager, wakeLock,
-                mClock, mServiceContext.getSystemService(TelephonyManager.class), mSettings,
+                mClock, mTelephonyManager, mSettings,
                 mStatsFactory, new NetworkStatsObservers(), mStatsDir, getBaseDir(mStatsDir), deps);
 
         mElapsedRealtime = 0L;
@@ -216,6 +224,12 @@
                 ArgumentCaptor.forClass(INetworkManagementEventObserver.class);
         verify(mNetManager).registerObserver(networkObserver.capture());
         mNetworkObserver = networkObserver.getValue();
+
+        // Capture the phone state listener that is created by the service.
+        final ArgumentCaptor<PhoneStateListener> phoneStateListenerCaptor =
+                ArgumentCaptor.forClass(PhoneStateListener.class);
+        verify(mTelephonyManager).listen(phoneStateListenerCaptor.capture(), anyInt());
+        mPhoneStateListener = phoneStateListenerCaptor.getValue();
     }
 
     @NonNull
@@ -534,7 +548,7 @@
     }
 
     @Test
-    public void testUid3g4gCombinedByTemplate() throws Exception {
+    public void testUid3gWimaxCombinedByTemplate() throws Exception {
         // pretend that network comes online
         expectDefaultSettings();
         NetworkState[] states = new NetworkState[] {buildMobile3gState(IMSI_1)};
@@ -558,10 +572,10 @@
         assertUidTotal(sTemplateImsi1, UID_RED, 1024L, 8L, 1024L, 8L, 5);
 
 
-        // now switch over to 4g network
+        // now switch over to wimax network
         incrementCurrentTime(HOUR_IN_MILLIS);
         expectDefaultSettings();
-        states = new NetworkState[] {buildMobile4gState(TEST_IFACE2)};
+        states = new NetworkState[] {buildWimaxState(TEST_IFACE2)};
         expectNetworkStatsSummary(buildEmptyStats());
         expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
                 .insertEntry(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE, 1024L, 8L, 1024L, 8L, 0L)
@@ -589,6 +603,89 @@
     }
 
     @Test
+    public void testMobileStatsByRatType() throws Exception {
+        final NetworkTemplate template3g =
+                buildTemplateMobileWithRatType(null, TelephonyManager.NETWORK_TYPE_UMTS);
+        final NetworkTemplate template4g =
+                buildTemplateMobileWithRatType(null, TelephonyManager.NETWORK_TYPE_LTE);
+        final NetworkTemplate template5g =
+                buildTemplateMobileWithRatType(null, TelephonyManager.NETWORK_TYPE_NR);
+        final NetworkState[] states = new NetworkState[]{buildMobile3gState(IMSI_1)};
+
+        // 3G network comes online.
+        expectNetworkStatsSummary(buildEmptyStats());
+        expectNetworkStatsUidDetail(buildEmptyStats());
+
+        setMobileRatTypeAndWaitForIdle(TelephonyManager.NETWORK_TYPE_UMTS);
+        mService.forceUpdateIfaces(NETWORKS_MOBILE, states, getActiveIface(states),
+                new VpnInfo[0]);
+
+        // Create some traffic.
+        incrementCurrentTime(MINUTE_IN_MILLIS);
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
+                .addEntry(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE,
+                        12L, 18L, 14L, 1L, 0L)));
+        forcePollAndWaitForIdle();
+
+        // Verify the 3g template gets stats.
+        assertUidTotal(sTemplateImsi1, UID_RED, 12L, 18L, 14L, 1L, 0);
+        assertUidTotal(template3g, UID_RED, 12L, 18L, 14L, 1L, 0);
+        assertUidTotal(template4g, UID_RED, 0L, 0L, 0L, 0L, 0);
+        assertUidTotal(template5g, UID_RED, 0L, 0L, 0L, 0L, 0);
+
+        // 4G network comes online.
+        incrementCurrentTime(MINUTE_IN_MILLIS);
+        setMobileRatTypeAndWaitForIdle(TelephonyManager.NETWORK_TYPE_LTE);
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
+                // Append more traffic on existing 3g stats entry.
+                .addEntry(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE,
+                        16L, 22L, 17L, 2L, 0L))
+                // Add entry that is new on 4g.
+                .addEntry(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_FOREGROUND, TAG_NONE,
+                        33L, 27L, 8L, 10L, 1L)));
+        forcePollAndWaitForIdle();
+
+        // Verify ALL_MOBILE template gets all. 3g template counters do not increase.
+        assertUidTotal(sTemplateImsi1, UID_RED, 49L, 49L, 25L, 12L, 1);
+        assertUidTotal(template3g, UID_RED, 12L, 18L, 14L, 1L, 0);
+        // Verify 4g template counts appended stats on existing entry and newly created entry.
+        assertUidTotal(template4g, UID_RED, 4L + 33L, 4L + 27L, 3L + 8L, 1L + 10L, 1);
+        // Verify 5g template doesn't get anything since no traffic is generated on 5g.
+        assertUidTotal(template5g, UID_RED, 0L, 0L, 0L, 0L, 0);
+
+        // 5g network comes online.
+        incrementCurrentTime(MINUTE_IN_MILLIS);
+        setMobileRatTypeAndWaitForIdle(TelephonyManager.NETWORK_TYPE_NR);
+        expectNetworkStatsUidDetail(new NetworkStats(getElapsedRealtime(), 1)
+                // Existing stats remain.
+                .addEntry(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE,
+                        16L, 22L, 17L, 2L, 0L))
+                .addEntry(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_FOREGROUND, TAG_NONE,
+                        33L, 27L, 8L, 10L, 1L))
+                // Add some traffic on 5g.
+                .addEntry(new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT, TAG_NONE,
+                5L, 13L, 31L, 9L, 2L)));
+        forcePollAndWaitForIdle();
+
+        // Verify ALL_MOBILE template gets all.
+        assertUidTotal(sTemplateImsi1, UID_RED, 54L, 62L, 56L, 21L, 3);
+        // 3g/4g template counters do not increase.
+        assertUidTotal(template3g, UID_RED, 12L, 18L, 14L, 1L, 0);
+        assertUidTotal(template4g, UID_RED, 4L + 33L, 4L + 27L, 3L + 8L, 1L + 10L, 1);
+        // Verify 5g template gets the 5g count.
+        assertUidTotal(template5g, UID_RED, 5L, 13L, 31L, 9L, 2);
+    }
+
+    // TODO: support per IMSI state
+    private void setMobileRatTypeAndWaitForIdle(int ratType) {
+        final ServiceState mockSs = mock(ServiceState.class);
+        when(mockSs.getDataNetworkType()).thenReturn(ratType);
+        mPhoneStateListener.onServiceStateChanged(mockSs);
+
+        HandlerUtilsKt.waitForIdle(mHandlerThread, WAIT_TIMEOUT);
+    }
+
+    @Test
     public void testSummaryForAllUid() throws Exception {
         // pretend that network comes online
         expectDefaultSettings();
@@ -1197,6 +1294,7 @@
         when(mSettings.getPollInterval()).thenReturn(HOUR_IN_MILLIS);
         when(mSettings.getPollDelay()).thenReturn(0L);
         when(mSettings.getSampleEnabled()).thenReturn(true);
+        when(mSettings.getCombineSubtypeEnabled()).thenReturn(false);
 
         final Config config = new Config(bucketDuration, deleteAge, deleteAge);
         when(mSettings.getDevConfig()).thenReturn(config);
@@ -1242,6 +1340,7 @@
         final NetworkCapabilities capabilities = new NetworkCapabilities();
         capabilities.setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED, !isMetered);
         capabilities.setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING, true);
+        capabilities.addTransportType(NetworkCapabilities.TRANSPORT_WIFI);
         return new NetworkState(info, prop, capabilities, WIFI_NETWORK, null, TEST_SSID);
     }
 
@@ -1259,10 +1358,11 @@
         final NetworkCapabilities capabilities = new NetworkCapabilities();
         capabilities.setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED, false);
         capabilities.setCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING, !isRoaming);
+        capabilities.addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR);
         return new NetworkState(info, prop, capabilities, MOBILE_NETWORK, subscriberId, null);
     }
 
-    private static NetworkState buildMobile4gState(String iface) {
+    private static NetworkState buildWimaxState(@NonNull String iface) {
         final NetworkInfo info = new NetworkInfo(TYPE_WIMAX, 0, null, null);
         info.setDetailedState(DetailedState.CONNECTED, null, null);
         final LinkProperties prop = new LinkProperties();