Merge changes from topic "bandwidth-limiting" am: 6b5b7b40d8 am: 4eeea56e39

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/1955585

Change-Id: I059668f6f5d5c1fa2faa3e2bca2f062f78cb4314
diff --git a/framework/api/module-lib-current.txt b/framework/api/module-lib-current.txt
index 64e159b..92dc2f4 100644
--- a/framework/api/module-lib-current.txt
+++ b/framework/api/module-lib-current.txt
@@ -69,6 +69,7 @@
     method @NonNull public static java.time.Duration getDnsResolverSampleValidityDuration(@NonNull android.content.Context, @NonNull java.time.Duration);
     method public static int getDnsResolverSuccessThresholdPercent(@NonNull android.content.Context, int);
     method @Nullable public static android.net.ProxyInfo getGlobalProxy(@NonNull android.content.Context);
+    method public static long getIngressRateLimitInBytesPerSecond(@NonNull android.content.Context);
     method @NonNull public static java.time.Duration getMobileDataActivityTimeout(@NonNull android.content.Context, @NonNull java.time.Duration);
     method public static boolean getMobileDataAlwaysOn(@NonNull android.content.Context, boolean);
     method @NonNull public static java.util.Set<java.lang.Integer> getMobileDataPreferredUids(@NonNull android.content.Context);
@@ -89,6 +90,7 @@
     method public static void setDnsResolverSampleValidityDuration(@NonNull android.content.Context, @NonNull java.time.Duration);
     method public static void setDnsResolverSuccessThresholdPercent(@NonNull android.content.Context, @IntRange(from=0, to=100) int);
     method public static void setGlobalProxy(@NonNull android.content.Context, @NonNull android.net.ProxyInfo);
+    method public static void setIngressRateLimitInBytesPerSecond(@NonNull android.content.Context, @IntRange(from=0xffffffff, to=java.lang.Integer.MAX_VALUE) long);
     method public static void setMobileDataActivityTimeout(@NonNull android.content.Context, @NonNull java.time.Duration);
     method public static void setMobileDataAlwaysOn(@NonNull android.content.Context, boolean);
     method public static void setMobileDataPreferredUids(@NonNull android.content.Context, @NonNull java.util.Set<java.lang.Integer>);
diff --git a/framework/src/android/net/ConnectivitySettingsManager.java b/framework/src/android/net/ConnectivitySettingsManager.java
index 8fc0065..4e28b29 100644
--- a/framework/src/android/net/ConnectivitySettingsManager.java
+++ b/framework/src/android/net/ConnectivitySettingsManager.java
@@ -384,6 +384,14 @@
             "uids_allowed_on_restricted_networks";
 
     /**
+     * A global rate limit that applies to all networks with NET_CAPABILITY_INTERNET when enabled.
+     *
+     * @hide
+     */
+    public static final String INGRESS_RATE_LIMIT_BYTES_PER_SECOND =
+            "ingress_rate_limit_bytes_per_second";
+
+    /**
      * Get mobile data activity timeout from {@link Settings}.
      *
      * @param context The {@link Context} to query the setting.
@@ -1071,4 +1079,37 @@
         Settings.Global.putString(context.getContentResolver(), UIDS_ALLOWED_ON_RESTRICTED_NETWORKS,
                 uids);
     }
+
+    /**
+     * Get the global ingress rate limit from {@link Settings}.
+     *
+     * The limit only applies to networks that provide internet connectivity. If the setting is
+     * unset, -1 is returned, which means that rate limiting is disabled.
+     *
+     * @param context The {@link Context} to query the setting.
+     * @return The rate limit in number of bytes per second or -1 if disabled.
+     */
+    public static long getIngressRateLimitInBytesPerSecond(@NonNull Context context) {
+        return Settings.Global.getLong(context.getContentResolver(),
+                INGRESS_RATE_LIMIT_BYTES_PER_SECOND, -1);
+    }
+
+    /**
+     * Set the global network bandwidth rate limit, which only applies to networks that provide
+     * internet connectivity.
+     *
+     * @param context The {@link Context} to set the setting.
+     * @param rateLimitInBytesPerSec The rate limit in number of bytes per second or -1 to disable.
+     * @throws IllegalArgumentException if the value is outside [-1, Integer.MAX_VALUE].
+     */
+    public static void setIngressRateLimitInBytesPerSecond(@NonNull Context context,
+            @IntRange(from = -1, to = Integer.MAX_VALUE) long rateLimitInBytesPerSec) {
+        // ConnectivityService casts this value to int, so the upper bound must be enforced too.
+        if (rateLimitInBytesPerSec < -1 || rateLimitInBytesPerSec > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException(
+                    "Rate limit must be within the range [-1, Integer.MAX_VALUE]");
+        }
+        Settings.Global.putLong(context.getContentResolver(),
+                INGRESS_RATE_LIMIT_BYTES_PER_SECOND, rateLimitInBytesPerSec);
+    }
 }
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 2b70934..7bb4529 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -90,6 +90,7 @@
 import static android.os.Process.INVALID_UID;
 import static android.os.Process.VPN_UID;
 import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
+import static android.system.OsConstants.ETH_P_ALL;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
 
@@ -197,6 +198,7 @@
 import android.net.resolv.aidl.Nat64PrefixEventParcel;
 import android.net.resolv.aidl.PrivateDnsValidationEventParcel;
 import android.net.shared.PrivateDnsConfig;
+import android.net.util.InterfaceParams;
 import android.net.util.MultinetworkPolicyTracker;
 import android.os.BatteryStatsManager;
 import android.os.Binder;
@@ -248,6 +250,7 @@
 import com.android.net.module.util.LocationPermissionChecker;
 import com.android.net.module.util.NetworkCapabilitiesUtils;
 import com.android.net.module.util.PermissionUtils;
+import com.android.net.module.util.TcUtils;
 import com.android.net.module.util.netlink.InetDiagMessage;
 import com.android.server.connectivity.AutodestructReference;
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator;
@@ -274,6 +277,7 @@
 import libcore.io.IoUtils;
 
 import java.io.FileDescriptor;
+import java.io.IOException;
 import java.io.PrintWriter;
 import java.io.Writer;
 import java.net.Inet4Address;
@@ -709,6 +713,11 @@
     private static final int EVENT_SET_TEST_ALLOW_BAD_WIFI_UNTIL = 55;
 
     /**
+     * Used internally when INGRESS_RATE_LIMIT_BYTES_PER_SECOND setting changes.
+     */
+    private static final int EVENT_INGRESS_RATE_LIMIT_CHANGED = 56;
+
+    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -725,6 +734,18 @@
      */
     private static final long MAX_TEST_ALLOW_BAD_WIFI_UNTIL_MS = 5 * 60 * 1000L;
 
+    /**
+     * The priority of the tc police rate limiter -- smaller value is higher priority.
+     * This value needs to be coordinated with PRIO_CLAT, PRIO_TETHER4, and PRIO_TETHER6.
+     */
+    private static final short TC_PRIO_POLICE = 1;
+
+    /**
+     * The BPF program attached to the tc-police hook to account for to-be-dropped traffic.
+     */
+    private static final String TC_POLICE_BPF_PROG_PATH =
+            "/sys/fs/bpf/prog_netd_schedact_ingress_account";
+
     private static String eventName(int what) {
         return sMagicDecoderRing.get(what, Integer.toString(what));
     }
@@ -815,6 +836,12 @@
     final Map<IBinder, ConnectivityDiagnosticsCallbackInfo> mConnectivityDiagnosticsCallbacks =
             new HashMap<>();
 
+    // Rate limit applicable to all internet capable networks (-1 = disabled). This value is
+    // configured via {@link
+    // ConnectivitySettingsManager#INGRESS_RATE_LIMIT_BYTES_PER_SECOND}
+    // Only the handler thread is allowed to access this field.
+    private long mIngressRateLimit = -1;
+
     /**
      * Implements support for the legacy "one network per network type" model.
      *
@@ -1367,6 +1394,48 @@
         public BpfNetMaps getBpfNetMaps(INetd netd) {
             return new BpfNetMaps(netd);
         }
+
+        /**
+         * Wraps {@link TcUtils#tcFilterAddDevIngressPolice}; rates above INT_MAX are clamped.
+         */
+        public void enableIngressRateLimit(String iface, long rateInBytesPerSecond) {
+            final InterfaceParams params = InterfaceParams.getByName(iface);
+            if (params == null) {
+                // the interface might have disappeared.
+                logw("Failed to get interface params for interface " + iface);
+                return;
+            }
+            // The stored setting is not guaranteed to fit in an int, so clamp it rather than
+            // letting the narrowing cast wrap around to an arbitrary (possibly zero) rate.
+            // TODO: add long/uint64 support to tcFilterAddDevIngressPolice.
+            final int rate = (int) Math.min(rateInBytesPerSecond, Integer.MAX_VALUE);
+            try {
+                TcUtils.tcFilterAddDevIngressPolice(params.index, TC_PRIO_POLICE, (short) ETH_P_ALL,
+                        rate, TC_POLICE_BPF_PROG_PATH);
+            } catch (IOException e) {
+                loge("TcUtils.tcFilterAddDevIngressPolice(ifaceIndex=" + params.index
+                        + ", PRIO_POLICE, ETH_P_ALL, rateInBytesPerSecond=" + rateInBytesPerSecond
+                        + ", bpfProgPath=" + TC_POLICE_BPF_PROG_PATH + ") failure: ", e);
+            }
+        }
+
+        /** Removes the ingress police filter; wraps {@link TcUtils#tcFilterDelDev}. */
+        public void disableIngressRateLimit(String iface) {
+            final InterfaceParams ifaceParams = InterfaceParams.getByName(iface);
+            if (ifaceParams == null) {
+                // Nothing to remove from: the interface may already be gone.
+                logw("Failed to get interface params for interface " + iface);
+                return;
+            }
+            final int ifIndex = ifaceParams.index;
+            // Best effort: a failure to delete the filter is only logged.
+            try {
+                TcUtils.tcFilterDelDev(ifIndex, true, TC_PRIO_POLICE, (short) ETH_P_ALL);
+            } catch (IOException e) {
+                loge("TcUtils.tcFilterDelDev(ifaceIndex=" + ifIndex
+                        + ", ingress=true, PRIO_POLICE, ETH_P_ALL) failure: ", e);
+            }
+        }
     }
 
     public ConnectivityService(Context context) {
@@ -1541,6 +1610,9 @@
         } catch (ErrnoException e) {
             loge("Unable to create DscpPolicyTracker");
         }
+
+        mIngressRateLimit = ConnectivitySettingsManager.getIngressRateLimitInBytesPerSecond(
+                mContext);
     }
 
     private static NetworkCapabilities createDefaultNetworkCapabilitiesForUid(int uid) {
@@ -1611,6 +1683,11 @@
         mHandler.sendEmptyMessage(EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED);
     }
 
+    @VisibleForTesting
+    void updateIngressRateLimit() {
+        mHandler.sendEmptyMessage(EVENT_INGRESS_RATE_LIMIT_CHANGED);
+    }
+
     private void handleAlwaysOnNetworkRequest(NetworkRequest networkRequest, int id) {
         final boolean enable = mContext.getResources().getBoolean(id);
         handleAlwaysOnNetworkRequest(networkRequest, enable);
@@ -1672,6 +1749,12 @@
         mSettingsObserver.observe(
                 Settings.Secure.getUriFor(ConnectivitySettingsManager.MOBILE_DATA_PREFERRED_UIDS),
                 EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED);
+
+        // Watch for ingress rate limit changes; the setting is stored in Settings.Global.
+        mSettingsObserver.observe(
+                Settings.Global.getUriFor(
+                        ConnectivitySettingsManager.INGRESS_RATE_LIMIT_BYTES_PER_SECOND),
+                EVENT_INGRESS_RATE_LIMIT_CHANGED);
     }
 
     private void registerPrivateDnsSettingsCallbacks() {
@@ -4091,6 +4174,11 @@
             // for an unnecessarily long time.
             destroyNativeNetwork(nai);
             mDnsManager.removeNetwork(nai.network);
+
+            // clean up tc police filters on interface.
+            if (canNetworkBeRateLimited(nai) && mIngressRateLimit >= 0) {
+                mDeps.disableIngressRateLimit(nai.linkProperties.getInterfaceName());
+            }
         }
         mNetIdManager.releaseNetId(nai.network.getNetId());
         nai.onNetworkDestroyed();
@@ -5158,6 +5246,9 @@
                     final long timeMs = ((Long) msg.obj).longValue();
                     mMultinetworkPolicyTracker.setTestAllowBadWifiUntil(timeMs);
                     break;
+                case EVENT_INGRESS_RATE_LIMIT_CHANGED:
+                    handleIngressRateLimitChanged();
+                    break;
             }
         }
     }
@@ -8854,6 +8945,19 @@
             // A network that has just connected has zero requests and is thus a foreground network.
             networkAgent.networkCapabilities.addCapability(NET_CAPABILITY_FOREGROUND);
 
+            // If a rate limit has been configured and is applicable to this network (network
+            // provides internet connectivity), apply it.
+            // Note: in case of a system server crash, there is a very small chance that this
+            // leaves some interfaces rate limited (i.e. if the rate limit had been changed just
+            // before the crash and was never applied). One solution would be to delete all
+            // potential tc police filters every time this is called. Since this is an unlikely
+            // scenario in the first place (and worst case, the interface stays rate limited until
+            // the device is rebooted), this seems a little overkill.
+            if (canNetworkBeRateLimited(networkAgent) && mIngressRateLimit >= 0) {
+                mDeps.enableIngressRateLimit(networkAgent.linkProperties.getInterfaceName(),
+                        mIngressRateLimit);
+            }
+
             if (!createNativeNetwork(networkAgent)) return;
             if (networkAgent.propagateUnderlyingCapabilities()) {
                 // Initialize the network's capabilities to their starting values according to the
@@ -10514,6 +10618,39 @@
         rematchAllNetworksAndRequests();
     }
 
+    private void handleIngressRateLimitChanged() {
+        // Dispatched by the handler for EVENT_INGRESS_RATE_LIMIT_CHANGED: refresh the cached
+        // limit and re-program every network it applies to.
+        final boolean hadRateLimit = mIngressRateLimit >= 0;
+        mIngressRateLimit = ConnectivitySettingsManager.getIngressRateLimitInBytesPerSecond(
+                mContext);
+        final boolean hasRateLimit = mIngressRateLimit >= 0;
+        for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
+            if (!canNetworkBeRateLimited(nai)) {
+                continue;
+            }
+            final String iface = nai.linkProperties.getInterfaceName();
+            // Drop the previously installed filter before installing one with the new rate.
+            if (hadRateLimit) mDeps.disableIngressRateLimit(iface);
+            if (hasRateLimit) mDeps.enableIngressRateLimit(iface, mIngressRateLimit);
+        }
+    }
+
+    private boolean canNetworkBeRateLimited(@NonNull final NetworkAgentInfo networkAgent) {
+        // Rate limits only apply to networks that provide internet connectivity.
+        final NetworkCapabilities caps = networkAgent.networkCapabilities;
+        if (!caps.hasCapability(NET_CAPABILITY_INTERNET)) {
+            return false;
+        }
+
+        // The callers address the filter by interface name; a missing name is not expected.
+        if (networkAgent.linkProperties.getInterfaceName() != null) {
+            return true;
+        }
+        logwtf("canNetworkBeRateLimited: LinkProperties#getInterfaceName returns null");
+        return false;
+    }
+
     private void enforceAutomotiveDevice() {
         PermissionUtils.enforceSystemFeature(mContext, PackageManager.FEATURE_AUTOMOTIVE,
                 "setOemNetworkPreference() is only available on automotive devices.");
diff --git a/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt b/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt
index ebaa787..8d8958d 100644
--- a/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt
+++ b/tests/common/java/android/net/ConnectivitySettingsManagerTest.kt
@@ -39,6 +39,7 @@
 import android.net.ConnectivitySettingsManager.getDnsResolverSampleRanges
 import android.net.ConnectivitySettingsManager.getDnsResolverSampleValidityDuration
 import android.net.ConnectivitySettingsManager.getDnsResolverSuccessThresholdPercent
+import android.net.ConnectivitySettingsManager.getIngressRateLimitInBytesPerSecond
 import android.net.ConnectivitySettingsManager.getMobileDataActivityTimeout
 import android.net.ConnectivitySettingsManager.getMobileDataAlwaysOn
 import android.net.ConnectivitySettingsManager.getNetworkSwitchNotificationMaximumDailyCount
@@ -51,6 +52,7 @@
 import android.net.ConnectivitySettingsManager.setDnsResolverSampleRanges
 import android.net.ConnectivitySettingsManager.setDnsResolverSampleValidityDuration
 import android.net.ConnectivitySettingsManager.setDnsResolverSuccessThresholdPercent
+import android.net.ConnectivitySettingsManager.setIngressRateLimitInBytesPerSecond
 import android.net.ConnectivitySettingsManager.setMobileDataActivityTimeout
 import android.net.ConnectivitySettingsManager.setMobileDataAlwaysOn
 import android.net.ConnectivitySettingsManager.setNetworkSwitchNotificationMaximumDailyCount
@@ -292,4 +294,19 @@
                 setter = { setWifiAlwaysRequested(context, it) },
                 testIntValues = intArrayOf(0))
     }
+
+    @Test
+    fun testInternetNetworkRateLimitInBytesPerSecond() {
+        // Remember the initial value so that it can be written back after the round trip.
+        val initialRate = getIngressRateLimitInBytesPerSecond(context)
+        val newRate = 1000L
+        setIngressRateLimitInBytesPerSecond(context, newRate)
+        assertEquals(newRate, getIngressRateLimitInBytesPerSecond(context))
+        setIngressRateLimitInBytesPerSecond(context, initialRate)
+        assertEquals(initialRate, getIngressRateLimitInBytesPerSecond(context))
+        // Values below -1 must be rejected by the setter.
+        assertFailsWith<IllegalArgumentException>("Expected failure, but setting accepted") {
+            setIngressRateLimitInBytesPerSecond(context, -10)
+        }
+    }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index e233e62..16b3d5a 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -179,7 +179,6 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.when;
 
 import static java.util.Arrays.asList;
 
@@ -388,6 +387,7 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
@@ -1963,6 +1963,25 @@
         public BpfNetMaps getBpfNetMaps(INetd netd) {
             return mBpfNetMaps;
         }
+
+        final ArrayTrackRecord<Pair<String, Long>> mRateLimitHistory = new ArrayTrackRecord<>();
+        final Map<String, Long> mActiveRateLimit = new HashMap<>();
+
+        @Override
+        public void enableIngressRateLimit(final String iface, final long rateInBytesPerSecond) {
+            mRateLimitHistory.add(new Pair<>(iface, rateInBytesPerSecond));
+            // Due to a TC limitation, the rate limit needs to be removed before it can be
+            // updated. Check that this happened.
+            assertEquals(-1L, (long) mActiveRateLimit.getOrDefault(iface, -1L));
+            mActiveRateLimit.put(iface, rateInBytesPerSecond);
+        }
+
+        @Override
+        public void disableIngressRateLimit(final String iface) {
+            mRateLimitHistory.add(new Pair<>(iface, -1L));
+            assertNotEquals(-1L, (long) mActiveRateLimit.getOrDefault(iface, -1L));
+            mActiveRateLimit.put(iface, -1L);
+        }
     }
 
     private static void initAlarmManager(final AlarmManager am, final Handler alarmHandler) {
@@ -5027,6 +5046,13 @@
         waitForIdle();
     }
 
+    private void setIngressRateLimit(int rateLimitInBytesPerSec) {
+        ConnectivitySettingsManager.setIngressRateLimitInBytesPerSecond(mServiceContext,
+                rateLimitInBytesPerSec);
+        mService.updateIngressRateLimit();
+        waitForIdle();
+    }
+
     private boolean isForegroundNetwork(TestNetworkAgentWrapper network) {
         NetworkCapabilities nc = mCm.getNetworkCapabilities(network.getNetwork());
         assertNotNull(nc);
@@ -15339,4 +15365,153 @@
                 ConnectivityManager.TYPE_NONE, null /* hostAddress */, "com.not.package.owner",
                 null /* callingAttributionTag */));
     }
+
+    @Test
+    public void testUpdateRateLimit_EnableDisable() throws Exception {
+        final LinkProperties wifiLp = new LinkProperties();
+        wifiLp.setInterfaceName(WIFI_IFNAME);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp);
+        mWiFiNetworkAgent.connect(true);
+
+        final LinkProperties cellLp = new LinkProperties();
+        cellLp.setInterfaceName(MOBILE_IFNAME);
+        mCellNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR, cellLp);
+        mCellNetworkAgent.connect(false);
+
+        waitForIdle();
+
+        final ArrayTrackRecord<Pair<String, Long>>.ReadHead readHeadWifi =
+                mDeps.mRateLimitHistory.newReadHead();
+        final ArrayTrackRecord<Pair<String, Long>>.ReadHead readHeadCell =
+                mDeps.mRateLimitHistory.newReadHead();
+
+        // set rate limit to 8MBit/s => 1MB/s
+        final int rateLimitInBytesPerSec = 1 * 1000 * 1000;
+        setIngressRateLimit(rateLimitInBytesPerSec);
+        // Both networks must be limited; interface names are matched by value, not identity.
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())
+                        && it.second == rateLimitInBytesPerSec));
+        assertNotNull(readHeadCell.poll(TIMEOUT_MS,
+                it -> it.first.equals(cellLp.getInterfaceName())
+                        && it.second == rateLimitInBytesPerSec));
+
+        // disable rate limiting
+        setIngressRateLimit(-1);
+
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName()) && it.second == -1));
+        assertNotNull(readHeadCell.poll(TIMEOUT_MS,
+                it -> it.first.equals(cellLp.getInterfaceName()) && it.second == -1));
+    }
+
+    @Test
+    public void testUpdateRateLimit_WhenNewNetworkIsAdded() throws Exception {
+        final LinkProperties wifiLp = new LinkProperties();
+        wifiLp.setInterfaceName(WIFI_IFNAME);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp);
+        mWiFiNetworkAgent.connect(true);
+
+        waitForIdle();
+
+        final ArrayTrackRecord<Pair<String, Long>>.ReadHead readHead =
+                mDeps.mRateLimitHistory.newReadHead();
+
+        // set rate limit to 8MBit/s => 1MB/s
+        final int rateLimitInBytesPerSec = 1 * 1000 * 1000;
+        setIngressRateLimit(rateLimitInBytesPerSec);
+        assertNotNull(readHead.poll(TIMEOUT_MS, it -> it.first.equals(wifiLp.getInterfaceName())
+                && it.second == rateLimitInBytesPerSec));
+        // A network that connects while the limit is active must be limited as well.
+        final LinkProperties cellLp = new LinkProperties();
+        cellLp.setInterfaceName(MOBILE_IFNAME);
+        mCellNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR, cellLp);
+        mCellNetworkAgent.connect(false);
+        assertNotNull(readHead.poll(TIMEOUT_MS, it -> it.first.equals(cellLp.getInterfaceName())
+                && it.second == rateLimitInBytesPerSec));
+    }
+
+    @Test
+    public void testUpdateRateLimit_OnlyAffectsInternetCapableNetworks() throws Exception {
+        final LinkProperties wifiLp = new LinkProperties();
+        wifiLp.setInterfaceName(WIFI_IFNAME);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp);
+        mWiFiNetworkAgent.connectWithoutInternet();
+
+        waitForIdle();
+        // Neither enabling nor disabling the limit may touch a network without internet.
+        setIngressRateLimit(1000);
+        setIngressRateLimit(-1);
+
+        final ArrayTrackRecord<Pair<String, Long>>.ReadHead readHeadWifi =
+                mDeps.mRateLimitHistory.newReadHead();
+        assertNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())));
+    }
+
+    @Test
+    public void testUpdateRateLimit_DisconnectingResetsRateLimit()
+            throws Exception {
+        // Steps:
+        // - connect network
+        // - set rate limit
+        // - disconnect network (interface still exists)
+        // - disable rate limit
+        // - connect network
+        // - ensure network interface is not rate limited
+        final LinkProperties wifiLp = new LinkProperties();
+        wifiLp.setInterfaceName(WIFI_IFNAME);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp);
+        mWiFiNetworkAgent.connect(true);
+        waitForIdle();
+
+        final ArrayTrackRecord<Pair<String, Long>>.ReadHead readHeadWifi =
+                mDeps.mRateLimitHistory.newReadHead();
+
+        int rateLimitInBytesPerSec = 1000;
+        setIngressRateLimit(rateLimitInBytesPerSec);
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())
+                        && it.second == rateLimitInBytesPerSec));
+
+        mWiFiNetworkAgent.disconnect();
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName()) && it.second == -1));
+
+        setIngressRateLimit(-1);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp);
+        mWiFiNetworkAgent.connect(true);
+        assertNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())));
+    }
+
+    @Test
+    public void testUpdateRateLimit_UpdateExistingRateLimit() throws Exception {
+        final LinkProperties wifiLp = new LinkProperties();
+        wifiLp.setInterfaceName(WIFI_IFNAME);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp);
+        mWiFiNetworkAgent.connect(true);
+        waitForIdle();
+
+        final ArrayTrackRecord<Pair<String, Long>>.ReadHead readHeadWifi =
+                mDeps.mRateLimitHistory.newReadHead();
+
+        // update an active ingress rate limit
+        setIngressRateLimit(1000);
+        setIngressRateLimit(2000);
+
+        // verify the following order of execution:
+        // 1. ingress rate limit set to 1000.
+        // 2. ingress rate limit disabled (triggered by updating active rate limit).
+        // 3. ingress rate limit set to 2000.
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())
+                        && it.second == 1000));
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())
+                        && it.second == -1));
+        assertNotNull(readHeadWifi.poll(TIMEOUT_MS,
+                it -> it.first.equals(wifiLp.getInterfaceName())
+                        && it.second == 2000));
+    }
 }