Merge "Run Connectivity cts tests against unbundled version of BouncyCastle"
diff --git a/Tethering/Android.bp b/Tethering/Android.bp
index 7427481..73175580 100644
--- a/Tethering/Android.bp
+++ b/Tethering/Android.bp
@@ -62,7 +62,10 @@
         "com.android.tethering",
     ],
     min_sdk_version: "30",
-    header_libs: ["bpf_syscall_wrappers"],
+    header_libs: [
+        "bpf_syscall_wrappers",
+        "bpf_tethering_headers",
+    ],
     srcs: [
         "jni/*.cpp",
     ],
diff --git a/Tethering/bpf_progs/Android.bp b/Tethering/bpf_progs/Android.bp
index d62e9a1..b06528b 100644
--- a/Tethering/bpf_progs/Android.bp
+++ b/Tethering/bpf_progs/Android.bp
@@ -15,6 +15,26 @@
 //
 
 //
+// error counter definitions shared by the BPF programs and JNI
+//
+cc_library_headers {
+    name: "bpf_tethering_headers",
+    vendor_available: false,
+    host_supported: false,
+    export_include_dirs: ["."], // lets dependents #include "bpf_tethering.h"
+    cflags: [ // NOTE(review): presumably no effect on a headers-only library; confirm
+        "-Wall",
+        "-Werror",
+    ],
+    sdk_version: "30",
+    min_sdk_version: "30",
+    apex_available: ["com.android.tethering"],
+    visibility: [ // restricts dependents to the Tethering package (its JNI library)
+        "//packages/modules/Connectivity/Tethering",
+    ],
+}
+
+//
 // bpf kernel programs
 //
 bpf {
diff --git a/Tethering/bpf_progs/bpf_tethering.h b/Tethering/bpf_progs/bpf_tethering.h
new file mode 100644
index 0000000..c8ada88
--- /dev/null
+++ b/Tethering/bpf_progs/bpf_tethering.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+// Common definitions for BPF code in the tethering mainline module.
+// These definitions are available to:
+// - The BPF programs in Tethering/bpf_progs/
+// - JNI code that depends on the bpf_tethering_headers library.
+
+// X-macro listing every tethering BPF error/punt counter. It is expanded twice below, into
+// the BPF_TETHER_ERR_* enum and the matching name table, so both always agree in order.
+// ERR(_MAX) must stay last: BPF_TETHER_ERR__MAX is then the number of real counters.
+#define BPF_TETHER_ERRORS    \
+    ERR(INVALID_IP_VERSION)  \
+    ERR(LOW_TTL)             \
+    ERR(INVALID_TCP_HEADER)  \
+    ERR(TCP_CONTROL_PACKET)  \
+    ERR(NON_GLOBAL_SRC)      \
+    ERR(NON_GLOBAL_DST)      \
+    ERR(LOCAL_SRC_DST)       \
+    ERR(NO_STATS_ENTRY)      \
+    ERR(NO_LIMIT_ENTRY)      \
+    ERR(BELOW_IPV4_MTU)      \
+    ERR(BELOW_IPV6_MTU)      \
+    ERR(LIMIT_REACHED)       \
+    ERR(CHANGE_HEAD_FAILED)  \
+    ERR(TOO_SHORT)           \
+    ERR(HAS_IP_OPTIONS)      \
+    ERR(IS_IP_FRAG)          \
+    ERR(CHECKSUM)            \
+    ERR(NON_TCP_UDP)         \
+    ERR(SHORT_TCP_HEADER)    \
+    ERR(SHORT_UDP_HEADER)    \
+    ERR(TRUNCATED_IPV4)      \
+    ERR(_MAX)
+
+#define ERR(x) BPF_TETHER_ERR_ ##x,
+enum {
+    BPF_TETHER_ERRORS
+};
+#undef ERR
+
+// Counter names indexed by BPF_TETHER_ERR_* value; the last entry is the "_MAX" sentinel.
+#define ERR(x) #x,
+static const char* const bpf_tether_errors[] = {
+    BPF_TETHER_ERRORS
+};
+#undef ERR
diff --git a/Tethering/bpf_progs/offload.c b/Tethering/bpf_progs/offload.c
index 2997031..bf60e67 100644
--- a/Tethering/bpf_progs/offload.c
+++ b/Tethering/bpf_progs/offload.c
@@ -26,11 +26,50 @@
 
 #include "bpf_helpers.h"
 #include "bpf_net_helpers.h"
+#include "bpf_tethering.h"
 #include "netdbpf/bpf_shared.h"
 
 // From kernel:include/net/ip.h
 #define IP_DF 0x4000  // Flag: "Don't Fragment"
 
+// ----- Helper macros for offsets to fields -----
+
+// They all assume simple IP packets:
+//   - no VLAN ethernet tags
+//   - no IPv4 options (see IP4_HLEN/IP4_TCP_OFFSET/IP4_UDP_OFFSET)
+//   - no IPv6 extension headers
+//   - no TCP options (see TCP_HLEN)
+
+// ETH_HLEN is not defined here: presumably it comes from <linux/if_ether.h> (TODO confirm).
+#define IP4_HLEN sizeof(struct iphdr)  // minimal 20 byte header, i.e. ihl == 5 (no options)
+#define IP6_HLEN sizeof(struct ipv6hdr)
+#define TCP_HLEN sizeof(struct tcphdr)
+#define UDP_HLEN sizeof(struct udphdr)
+
+// Offsets from beginning of L4 (TCP/UDP) header
+#define TCP_OFFSET(field) offsetof(struct tcphdr, field)
+#define UDP_OFFSET(field) offsetof(struct udphdr, field)
+
+// Offsets from beginning of L3 (IPv4) header
+#define IP4_OFFSET(field) offsetof(struct iphdr, field)
+#define IP4_TCP_OFFSET(field) (IP4_HLEN + TCP_OFFSET(field))
+#define IP4_UDP_OFFSET(field) (IP4_HLEN + UDP_OFFSET(field))
+
+// Offsets from beginning of L3 (IPv6) header
+#define IP6_OFFSET(field) offsetof(struct ipv6hdr, field)
+#define IP6_TCP_OFFSET(field) (IP6_HLEN + TCP_OFFSET(field))
+#define IP6_UDP_OFFSET(field) (IP6_HLEN + UDP_OFFSET(field))
+
+// Offsets from beginning of L2 (ie. Ethernet) header (which must be present)
+#define ETH_IP4_OFFSET(field) (ETH_HLEN + IP4_OFFSET(field))
+#define ETH_IP4_TCP_OFFSET(field) (ETH_HLEN + IP4_TCP_OFFSET(field))
+#define ETH_IP4_UDP_OFFSET(field) (ETH_HLEN + IP4_UDP_OFFSET(field))
+#define ETH_IP6_OFFSET(field) (ETH_HLEN + IP6_OFFSET(field))
+#define ETH_IP6_TCP_OFFSET(field) (ETH_HLEN + IP6_TCP_OFFSET(field))
+#define ETH_IP6_UDP_OFFSET(field) (ETH_HLEN + IP6_UDP_OFFSET(field))
+
+// ----- Tethering stats and data limits -----
+
 // Tethering stats, indexed by upstream interface.
 DEFINE_BPF_MAP_GRW(tether_stats_map, HASH, TetherStatsKey, TetherStatsValue, 16, AID_NETWORK_STACK)
 
@@ -49,6 +88,19 @@
 DEFINE_BPF_MAP_GRW(tether_upstream6_map, HASH, TetherUpstream6Key, Tether6Value, 64,
                    AID_NETWORK_STACK)
 
+DEFINE_BPF_MAP_GRW(tether_error_map, ARRAY, __u32, __u32, BPF_TETHER_ERR__MAX,
+                   AID_NETWORK_STACK)  // One __u32 counter per BPF_TETHER_ERR_* value.
+
+#define COUNT_AND_RETURN(counter, ret) do {                    \
+    __u32 code = BPF_TETHER_ERR_ ## counter;                 \
+    __u32 *count = bpf_tether_error_map_lookup_elem(&code);  \
+    if (count) __sync_fetch_and_add(count, 1);               \
+    return ret;                                              \
+} while(0)  // Bumps tether_error_map[BPF_TETHER_ERR_<counter>] if found, then returns ret.
+
+#define DROP(counter) COUNT_AND_RETURN(counter, TC_ACT_SHOT)  // Count, then drop the packet.
+#define PUNT(counter) COUNT_AND_RETURN(counter, TC_ACT_OK)  // Count, then leave it to the stack.
+
 static inline __always_inline int do_forward6(struct __sk_buff* skb, const bool is_ethernet,
         const bool downstream) {
     const int l2_header_size = is_ethernet ? sizeof(struct ethhdr) : 0;
@@ -70,11 +122,11 @@
     if (is_ethernet && (eth->h_proto != htons(ETH_P_IPV6))) return TC_ACT_OK;
 
     // IP version must be 6
-    if (ip6->version != 6) return TC_ACT_OK;
+    if (ip6->version != 6) PUNT(INVALID_IP_VERSION);
 
     // Cannot decrement during forward if already zero or would be zero,
     // Let the kernel's stack handle these cases and generate appropriate ICMP errors.
-    if (ip6->hop_limit <= 1) return TC_ACT_OK;
+    if (ip6->hop_limit <= 1) PUNT(LOW_TTL);
 
     // If hardware offload is running and programming flows based on conntrack entries,
     // try not to interfere with it.
@@ -82,27 +134,28 @@
         struct tcphdr* tcph = (void*)(ip6 + 1);
 
         // Make sure we can get at the tcp header
-        if (data + l2_header_size + sizeof(*ip6) + sizeof(*tcph) > data_end) return TC_ACT_OK;
+        if (data + l2_header_size + sizeof(*ip6) + sizeof(*tcph) > data_end)
+            PUNT(INVALID_TCP_HEADER);
 
         // Do not offload TCP packets with any one of the SYN/FIN/RST flags
-        if (tcph->syn || tcph->fin || tcph->rst) return TC_ACT_OK;
+        if (tcph->syn || tcph->fin || tcph->rst) PUNT(TCP_CONTROL_PACKET);
     }
 
     // Protect against forwarding packets sourced from ::1 or fe80::/64 or other weirdness.
     __be32 src32 = ip6->saddr.s6_addr32[0];
     if (src32 != htonl(0x0064ff9b) &&                        // 64:ff9b:/32 incl. XLAT464 WKP
         (src32 & htonl(0xe0000000)) != htonl(0x20000000))    // 2000::/3 Global Unicast
-        return TC_ACT_OK;
+        PUNT(NON_GLOBAL_SRC);
 
     // Protect against forwarding packets destined to ::1 or fe80::/64 or other weirdness.
     __be32 dst32 = ip6->daddr.s6_addr32[0];
     if (dst32 != htonl(0x0064ff9b) &&                        // 64:ff9b:/32 incl. XLAT464 WKP
         (dst32 & htonl(0xe0000000)) != htonl(0x20000000))    // 2000::/3 Global Unicast
-        return TC_ACT_OK;
+        PUNT(NON_GLOBAL_DST);
 
     // In the upstream direction do not forward traffic within the same /64 subnet.
     if (!downstream && (src32 == dst32) && (ip6->saddr.s6_addr32[1] == ip6->daddr.s6_addr32[1]))
-        return TC_ACT_OK;
+        PUNT(LOCAL_SRC_DST);
 
     TetherDownstream6Key kd = {
             .iif = skb->ifindex,
@@ -124,15 +177,15 @@
     TetherStatsValue* stat_v = bpf_tether_stats_map_lookup_elem(&stat_and_limit_k);
 
     // If we don't have anywhere to put stats, then abort...
-    if (!stat_v) return TC_ACT_OK;
+    if (!stat_v) PUNT(NO_STATS_ENTRY);
 
     uint64_t* limit_v = bpf_tether_limit_map_lookup_elem(&stat_and_limit_k);
 
     // If we don't have a limit, then abort...
-    if (!limit_v) return TC_ACT_OK;
+    if (!limit_v) PUNT(NO_LIMIT_ENTRY);
 
     // Required IPv6 minimum mtu is 1280, below that not clear what we should do, abort...
-    if (v->pmtu < IPV6_MIN_MTU) return TC_ACT_OK;
+    if (v->pmtu < IPV6_MIN_MTU) PUNT(BELOW_IPV6_MTU);
 
     // Approximate handling of TCP/IPv6 overhead for incoming LRO/GRO packets: default
     // outbound path mtu of 1500 is not necessarily correct, but worst case we simply
@@ -157,7 +210,7 @@
     // a packet we let the core stack deal with things.
     // (The core stack needs to handle limits correctly anyway,
     // since we don't offload all traffic in both directions)
-    if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) return TC_ACT_OK;
+    if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) PUNT(LIMIT_REACHED);
 
     if (!is_ethernet) {
         // Try to inject an ethernet header, and simply return if we fail.
@@ -165,7 +218,7 @@
         // because this is easier and the kernel will strip extraneous ethernet header.
         if (bpf_skb_change_head(skb, sizeof(struct ethhdr), /*flags*/ 0)) {
             __sync_fetch_and_add(downstream ? &stat_v->rxErrors : &stat_v->txErrors, 1);
-            return TC_ACT_OK;
+            PUNT(CHANGE_HEAD_FAILED);
         }
 
         // bpf_skb_change_head() invalidates all pointers - reload them
@@ -177,7 +230,7 @@
         // I do not believe this can ever happen, but keep the verifier happy...
         if (data + sizeof(struct ethhdr) + sizeof(*ip6) > data_end) {
             __sync_fetch_and_add(downstream ? &stat_v->rxErrors : &stat_v->txErrors, 1);
-            return TC_ACT_SHOT;
+            DROP(TOO_SHORT);
         }
     };
 
@@ -308,10 +361,10 @@
     if (is_ethernet && (eth->h_proto != htons(ETH_P_IP))) return TC_ACT_OK;
 
     // IP version must be 4
-    if (ip->version != 4) return TC_ACT_OK;
+    if (ip->version != 4) PUNT(INVALID_IP_VERSION);
 
     // We cannot handle IP options, just standard 20 byte == 5 dword minimal IPv4 header
-    if (ip->ihl != 5) return TC_ACT_OK;
+    if (ip->ihl != 5) PUNT(HAS_IP_OPTIONS);
 
     // Calculate the IPv4 one's complement checksum of the IPv4 header.
     __wsum sum4 = 0;
@@ -322,36 +375,36 @@
     sum4 = (sum4 & 0xFFFF) + (sum4 >> 16);  // collapse u32 into range 1 .. 0x1FFFE
     sum4 = (sum4 & 0xFFFF) + (sum4 >> 16);  // collapse any potential carry into u16
     // for a correct checksum we should get *a* zero, but sum4 must be positive, ie 0xFFFF
-    if (sum4 != 0xFFFF) return TC_ACT_OK;
+    if (sum4 != 0xFFFF) PUNT(CHECKSUM);
 
     // Minimum IPv4 total length is the size of the header
-    if (ntohs(ip->tot_len) < sizeof(*ip)) return TC_ACT_OK;
+    if (ntohs(ip->tot_len) < sizeof(*ip)) PUNT(TRUNCATED_IPV4);
 
     // We are incapable of dealing with IPv4 fragments
-    if (ip->frag_off & ~htons(IP_DF)) return TC_ACT_OK;
+    if (ip->frag_off & ~htons(IP_DF)) PUNT(IS_IP_FRAG);
 
     // Cannot decrement during forward if already zero or would be zero,
     // Let the kernel's stack handle these cases and generate appropriate ICMP errors.
-    if (ip->ttl <= 1) return TC_ACT_OK;
+    if (ip->ttl <= 1) PUNT(LOW_TTL);
 
     const bool is_tcp = (ip->protocol == IPPROTO_TCP);
 
     // We do not support anything besides TCP and UDP
-    if (!is_tcp && (ip->protocol != IPPROTO_UDP)) return TC_ACT_OK;
+    if (!is_tcp && (ip->protocol != IPPROTO_UDP)) PUNT(NON_TCP_UDP);
 
     struct tcphdr* tcph = is_tcp ? (void*)(ip + 1) : NULL;
     struct udphdr* udph = is_tcp ? NULL : (void*)(ip + 1);
 
     if (is_tcp) {
         // Make sure we can get at the tcp header
-        if (data + l2_header_size + sizeof(*ip) + sizeof(*tcph) > data_end) return TC_ACT_OK;
+        if (data + l2_header_size + sizeof(*ip) + sizeof(*tcph) > data_end) PUNT(SHORT_TCP_HEADER);
 
         // If hardware offload is running and programming flows based on conntrack entries, try not
         // to interfere with it, so do not offload TCP packets with any one of the SYN/FIN/RST flags
-        if (tcph->syn || tcph->fin || tcph->rst) return TC_ACT_OK;
+        if (tcph->syn || tcph->fin || tcph->rst) PUNT(TCP_CONTROL_PACKET);
     } else { // UDP
         // Make sure we can get at the udp header
-        if (data + l2_header_size + sizeof(*ip) + sizeof(*udph) > data_end) return TC_ACT_OK;
+        if (data + l2_header_size + sizeof(*ip) + sizeof(*udph) > data_end) PUNT(SHORT_UDP_HEADER);
     }
 
     Tether4Key k = {
@@ -375,15 +428,15 @@
     TetherStatsValue* stat_v = bpf_tether_stats_map_lookup_elem(&stat_and_limit_k);
 
     // If we don't have anywhere to put stats, then abort...
-    if (!stat_v) return TC_ACT_OK;
+    if (!stat_v) PUNT(NO_STATS_ENTRY);
 
     uint64_t* limit_v = bpf_tether_limit_map_lookup_elem(&stat_and_limit_k);
 
     // If we don't have a limit, then abort...
-    if (!limit_v) return TC_ACT_OK;
+    if (!limit_v) PUNT(NO_LIMIT_ENTRY);
 
     // Required IPv4 minimum mtu is 68, below that not clear what we should do, abort...
-    if (v->pmtu < 68) return TC_ACT_OK;
+    if (v->pmtu < 68) PUNT(BELOW_IPV4_MTU);
 
     // Approximate handling of TCP/IPv4 overhead for incoming LRO/GRO packets: default
     // outbound path mtu of 1500 is not necessarily correct, but worst case we simply
@@ -408,14 +461,78 @@
     // a packet we let the core stack deal with things.
     // (The core stack needs to handle limits correctly anyway,
     // since we don't offload all traffic in both directions)
-    if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) return TC_ACT_OK;
+    if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) PUNT(LIMIT_REACHED);
 
-    // TODO: replace Errors with Packets once implemented
-    __sync_fetch_and_add(downstream ? &stat_v->rxErrors : &stat_v->txErrors, packets);
+
+if (!is_tcp) return TC_ACT_OK; // HACK
+
+    if (!is_ethernet) {
+        // Try to inject an ethernet header, and simply return if we fail.
+        // We do this even if TX interface is RAWIP and thus does not need an ethernet header,
+        // because this is easier and the kernel will strip extraneous ethernet header.
+        if (bpf_skb_change_head(skb, sizeof(struct ethhdr), /*flags*/ 0)) {
+            __sync_fetch_and_add(downstream ? &stat_v->rxErrors : &stat_v->txErrors, 1);
+            PUNT(CHANGE_HEAD_FAILED);
+        }
+
+        // bpf_skb_change_head() invalidates all pointers - reload them
+        data = (void*)(long)skb->data;
+        data_end = (void*)(long)skb->data_end;
+        eth = data;
+        ip = (void*)(eth + 1);
+        tcph = is_tcp ? (void*)(ip + 1) : NULL;
+        udph = is_tcp ? NULL : (void*)(ip + 1);
+
+        // I do not believe this can ever happen, but keep the verifier happy...
+        if (data + sizeof(struct ethhdr) + sizeof(*ip) + (is_tcp ? sizeof(*tcph) : sizeof(*udph)) > data_end) {
+            __sync_fetch_and_add(downstream ? &stat_v->rxErrors : &stat_v->txErrors, 1);
+            DROP(TOO_SHORT);
+        }
+    };
+
+    // At this point we always have an ethernet header - which will get stripped by the
+    // kernel during transmit through a rawip interface.  ie. 'eth' pointer is valid.
+    // Additionally note that 'is_ethernet' and 'l2_header_size' are no longer correct.
+
+    // Overwrite any mac header with the new one
+    // For a rawip tx interface it will simply be a bunch of zeroes and later stripped.
+    *eth = v->macHeader;
+
+    const int sz4 = sizeof(__be32);
+    const __be32 old_daddr = k.dst4.s_addr;
+    const __be32 old_saddr = k.src4.s_addr;
+    const __be32 new_daddr = v->dst46.s6_addr32[3];
+    const __be32 new_saddr = v->src46.s6_addr32[3];
+
+    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), old_daddr, new_daddr, sz4 | BPF_F_PSEUDO_HDR);
+    bpf_l3_csum_replace(skb, ETH_IP4_OFFSET(check), old_daddr, new_daddr, sz4);
+    bpf_skb_store_bytes(skb, ETH_IP4_OFFSET(daddr), &new_daddr, sz4, 0);
+
+    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), old_saddr, new_saddr, sz4 | BPF_F_PSEUDO_HDR);
+    bpf_l3_csum_replace(skb, ETH_IP4_OFFSET(check), old_saddr, new_saddr, sz4);
+    bpf_skb_store_bytes(skb, ETH_IP4_OFFSET(saddr), &new_saddr, sz4, 0);
+
+    const int sz2 = sizeof(__be16);
+    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), k.srcPort, v->srcPort, sz2);
+    bpf_skb_store_bytes(skb, ETH_IP4_TCP_OFFSET(source), &v->srcPort, sz2, 0);
+
+    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), k.dstPort, v->dstPort, sz2);
+    bpf_skb_store_bytes(skb, ETH_IP4_TCP_OFFSET(dest), &v->dstPort, sz2, 0);
+
+// TTL dec
+
+// v->last_used = bpf_ktime_get_boot_ns();
+
+    __sync_fetch_and_add(downstream ? &stat_v->rxPackets : &stat_v->txPackets, packets);
     __sync_fetch_and_add(downstream ? &stat_v->rxBytes : &stat_v->txBytes, bytes);
 
-    // TODO: not actually implemented yet
-    return TC_ACT_OK;
+    // Redirect to forwarded interface.
+    //
+    // Note that bpf_redirect() cannot fail unless you pass invalid flags.
+    // The redirect actually happens after the ebpf program has already terminated,
+    // and can fail for example for mtu reasons at that point in time, but there's nothing
+    // we can do about it here.
+    return bpf_redirect(v->oif, 0 /* this is effectively BPF_F_EGRESS */);
 }
 
 // Real implementations for 5.9+ kernels
diff --git a/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp b/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp
new file mode 100644
index 0000000..27357f8
--- /dev/null
+++ b/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <jni.h>
+#include <nativehelper/JNIHelp.h>
+
+#include "bpf_tethering.h"
+
+namespace android {
+
+// Returns the names of all BPF tethering counters, indexed by BPF_TETHER_ERR_* value and
+// excluding the trailing "_MAX" sentinel. Returns null with a Java exception pending if any
+// JNI allocation fails. Registered as a static native method, hence the jclass parameter.
+static jobjectArray getBpfCounterNames(JNIEnv *env, jclass /* clazz */) {
+    const jsize size = BPF_TETHER_ERR__MAX;
+    jclass stringClass = env->FindClass("java/lang/String");
+    if (stringClass == nullptr) return nullptr;  // Class lookup failed; exception is pending.
+
+    jobjectArray ret = env->NewObjectArray(size, stringClass, nullptr);
+    env->DeleteLocalRef(stringClass);
+    if (ret == nullptr) return nullptr;  // Allocation failed; exception is pending.
+
+    for (jsize i = 0; i < size; i++) {
+        jstring name = env->NewStringUTF(bpf_tether_errors[i]);
+        if (name == nullptr) return nullptr;  // Allocation failed; exception is pending.
+        env->SetObjectArrayElement(ret, i, name);
+        // Release each element's local ref promptly instead of accumulating them.
+        env->DeleteLocalRef(name);
+    }
+    return ret;
+}
+
+/*
+ * JNI registration.
+ */
+static const JNINativeMethod gMethods[] = {
+    /* name, signature, funcPtr */
+    { "getBpfCounterNames", "()[Ljava/lang/String;", (void*) getBpfCounterNames },
+};
+
+int register_com_android_networkstack_tethering_BpfCoordinator(JNIEnv* env) {
+    return jniRegisterNativeMethods(env,
+            "com/android/networkstack/tethering/BpfCoordinator",
+            gMethods, NELEM(gMethods));
+}
+
+}  // namespace android
diff --git a/Tethering/jni/onload.cpp b/Tethering/jni/onload.cpp
index 3766de9..e31da60 100644
--- a/Tethering/jni/onload.cpp
+++ b/Tethering/jni/onload.cpp
@@ -24,6 +24,7 @@
 
 int register_android_net_util_TetheringUtils(JNIEnv* env);
 int register_com_android_networkstack_tethering_BpfMap(JNIEnv* env);
+int register_com_android_networkstack_tethering_BpfCoordinator(JNIEnv* env);
 
 extern "C" jint JNI_OnLoad(JavaVM* vm, void*) {
     JNIEnv *env;
@@ -36,6 +37,8 @@
 
     if (register_com_android_networkstack_tethering_BpfMap(env) < 0) return JNI_ERR;
 
+    if (register_com_android_networkstack_tethering_BpfCoordinator(env) < 0) return JNI_ERR;
+
     return JNI_VERSION_1_6;
 }
 
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index b17bfcf..985328f 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -59,6 +59,7 @@
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.NetworkStackConstants;
+import com.android.net.module.util.Struct;
 import com.android.networkstack.tethering.apishim.common.BpfCoordinatorShim;
 
 import java.net.Inet4Address;
@@ -84,6 +85,13 @@
  * @hide
  */
 public class BpfCoordinator {
+    // Ensure the JNI code is loaded. In production this will already have been loaded by
+    // TetherService, but for tests it needs to be either loaded here or loaded by every test.
+    // TODO: is there a better way?
+    static {  // Must stay above sBpfCounterNames, whose initializer calls native code.
+        System.loadLibrary("tetherutilsjni");
+    }
+
     static final boolean DOWNSTREAM = true;
     static final boolean UPSTREAM = false;
 
@@ -97,6 +105,10 @@
     private static final String TETHER_UPSTREAM6_FS_PATH = makeMapPath(UPSTREAM, 6);
     private static final String TETHER_STATS_MAP_PATH = makeMapPath("stats");
     private static final String TETHER_LIMIT_MAP_PATH = makeMapPath("limit");
+    private static final String TETHER_ERROR_MAP_PATH = makeMapPath("error");
+
+    /** Names of the BPF counters in bpf_tethering.h, indexed by counter id. Do not modify. */
+    public static final String[] sBpfCounterNames = getBpfCounterNames();
 
     private static String makeMapPath(String which) {
         return "/sys/fs/bpf/tethering/map_offload_tether_" + which + "_map";
@@ -717,6 +729,12 @@
             dumpIpv4ForwardingRules(pw);
             pw.decreaseIndent();
 
+            pw.println();
+            pw.println("Forwarding counters:");
+            pw.increaseIndent();
+            dumpCounters(pw);
+            pw.decreaseIndent();
+
             dumpDone.open();
         });
         if (!dumpDone.block(DUMP_TIMEOUT_MS)) {
@@ -766,17 +784,16 @@
     }
 
     private void dumpIpv6UpstreamRules(IndentingPrintWriter pw) {
-        final BpfMap<TetherUpstream6Key, Tether6Value> ipv6UpstreamMap = mDeps.getBpfUpstream6Map();
-        if (ipv6UpstreamMap == null) {
-            pw.println("No IPv6 upstream");
-            return;
-        }
-        try {
-            if (ipv6UpstreamMap.isEmpty()) {
+        try (BpfMap<TetherUpstream6Key, Tether6Value> map = mDeps.getBpfUpstream6Map()) {
+            if (map == null) {
+                pw.println("No IPv6 upstream");
+                return;
+            }
+            if (map.isEmpty()) {
                 pw.println("No IPv6 upstream rules");
                 return;
             }
-            ipv6UpstreamMap.forEach((k, v) -> pw.println(ipv6UpstreamRuletoString(k, v)));
+            map.forEach((k, v) -> pw.println(ipv6UpstreamRuletoString(k, v)));
         } catch (ErrnoException e) {
             pw.println("Error dumping IPv4 map: " + e);
         }
@@ -797,25 +814,54 @@
     }
 
     private void dumpIpv4ForwardingRules(IndentingPrintWriter pw) {
-        final BpfMap<Tether4Key, Tether4Value> ipv4UpstreamMap = mDeps.getBpfUpstream4Map();
-        if (ipv4UpstreamMap == null) {
-            pw.println("No IPv4 support");
-            return;
-        }
-        try {
-            if (ipv4UpstreamMap.isEmpty()) {
+        try (BpfMap<Tether4Key, Tether4Value> map = mDeps.getBpfUpstream4Map()) {
+            if (map == null) {  // try-with-resources skips close() on a null resource.
+                pw.println("No IPv4 support");
+                return;
+            }
+            if (map.isEmpty()) {  // NOTE(review): a throw here unbalances decreaseIndent() below.
                 pw.println("No IPv4 rules");
                 return;
             }
             pw.println("[IPv4]: iif(iface) oif(iface) src nat dst");
             pw.increaseIndent();
-            ipv4UpstreamMap.forEach((k, v) -> pw.println(ipv4RuleToString(k, v)));
+            map.forEach((k, v) -> pw.println(ipv4RuleToString(k, v)));
         } catch (ErrnoException e) {
             pw.println("Error dumping IPv4 map: " + e);
         }
         pw.decreaseIndent();
     }
 
+    /**
+     * Simple struct that only contains a u32. Must be public because Struct needs access to it.
+     * TODO: make this a public inner class of Struct so anyone can use it as, e.g., Struct.U32?
+     */
+    public static class U32Struct extends Struct {
+        @Struct.Field(order = 0, type = Struct.Type.U32)
+        public long val;  // long, since Java has no unsigned 32-bit int.
+    }
+
+    private void dumpCounters(@NonNull IndentingPrintWriter pw) {
+        try (BpfMap<U32Struct, U32Struct> map = new BpfMap<>(TETHER_ERROR_MAP_PATH,
+                BpfMap.BPF_F_RDONLY, U32Struct.class, U32Struct.class)) {
+            // Print each non-zero counter as "<name>: <count>", named via sBpfCounterNames.
+            map.forEach((k, v) -> {
+                if (v.val <= 0) return;
+                final String counterName;
+                if (k.val >= 0 && k.val < sBpfCounterNames.length) {
+                    counterName = sBpfCounterNames[(int) k.val];
+                } else {
+                    // Should never happen: names and counters both come from bpf_tethering.h.
+                    Log.wtf(TAG, "Unknown tethering counter type " + k.val);
+                    counterName = Long.toString(k.val);
+                }
+                pw.println(counterName + ": " + v.val);  // Not String.format: locale-safe digits.
+            });
+        } catch (ErrnoException e) {
+            pw.println("Error dumping counter map: " + e);
+        }
+    }
+
     /** IPv6 forwarding rule class. */
     public static class Ipv6ForwardingRule {
         public final int upstreamIfindex;
@@ -1298,4 +1344,6 @@
     final SparseArray<String> getInterfaceNamesForTesting() {
         return mInterfaceNames;
     }
+
+    private static native String[] getBpfCounterNames();  // Names indexed by BPF_TETHER_ERR_*.
 }
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfMap.java b/Tethering/src/com/android/networkstack/tethering/BpfMap.java
index 89caa8a..bc01dbd 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfMap.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfMap.java
@@ -218,7 +218,7 @@
     }
 
     @Override
-    public void close() throws Exception {
+    public void close() throws ErrnoException {  // Narrowed: callers catch only ErrnoException.
         closeMap(mMapFd);
     }
 
diff --git a/Tethering/tests/unit/Android.bp b/Tethering/tests/unit/Android.bp
index 5e4fe52..6c479a0 100644
--- a/Tethering/tests/unit/Android.bp
+++ b/Tethering/tests/unit/Android.bp
@@ -69,6 +69,7 @@
         // For mockito extended
         "libdexmakerjvmtiagent",
         "libstaticjvmtiagent",
+        "libtetherutilsjni",
     ],
 }
 
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java
index 8fadf9e..5c99c67 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MeterednessConfigurationRule.java
@@ -15,21 +15,20 @@
  */
 package com.android.cts.net.hostside;
 
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.resetMeteredNetwork;
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setupMeteredNetwork;
+import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setupActiveNetworkMeteredness;
 import static com.android.cts.net.hostside.Property.METERED_NETWORK;
 import static com.android.cts.net.hostside.Property.NON_METERED_NETWORK;
 
 import android.util.ArraySet;
-import android.util.Pair;
 
 import com.android.compatibility.common.util.BeforeAfterRule;
+import com.android.compatibility.common.util.ThrowingRunnable;
 
 import org.junit.runner.Description;
 import org.junit.runners.model.Statement;
 
 public class MeterednessConfigurationRule extends BeforeAfterRule {
-    private Pair<String, Boolean> mSsidAndInitialMeteredness;
+    private ThrowingRunnable mMeterednessResetter;
 
     @Override
     public void onBefore(Statement base, Description description) throws Throwable {
@@ -48,13 +47,13 @@
     }
 
     public void configureNetworkMeteredness(boolean metered) throws Exception {
-        mSsidAndInitialMeteredness = setupMeteredNetwork(metered);
+        mMeterednessResetter = setupActiveNetworkMeteredness(metered);  // Undone on reset.
     }
 
     public void resetNetworkMeteredness() throws Exception {
-        if (mSsidAndInitialMeteredness != null) {
-            resetMeteredNetwork(mSsidAndInitialMeteredness.first,
-                    mSsidAndInitialMeteredness.second);
+        if (mMeterednessResetter != null) {
+            mMeterednessResetter.run();
+            mMeterednessResetter = null;  // Makes a second reset a no-op.
         }
     }
 }
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
index 2ac29e7..955317b 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
@@ -17,16 +17,13 @@
 package com.android.cts.net.hostside;
 
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
+
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.canChangeActiveNetworkMeteredness;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setRestrictBackground;
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.isActiveNetworkMetered;
 import static com.android.cts.net.hostside.Property.BATTERY_SAVER_MODE;
 import static com.android.cts.net.hostside.Property.DATA_SAVER_MODE;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.junit.Assume.assumeTrue;
 
@@ -186,7 +183,7 @@
     public void setUp() throws Exception {
         super.setUp();
 
-        assumeTrue(isActiveNetworkMetered(true) || canChangeActiveNetworkMeteredness());
+        assumeTrue(canChangeActiveNetworkMeteredness());
 
         registerBroadcastReceiver();
 
@@ -198,13 +195,13 @@
         setBatterySaverMode(false);
         setRestrictBackground(false);
 
-        // Make wifi a metered network.
+        // Mark network as metered.
         mMeterednessConfiguration.configureNetworkMeteredness(true);
 
         // Register callback
         registerNetworkCallback((INetworkCallback.Stub) mTestNetworkCallback);
-        // Once the wifi is marked as metered, the wifi will reconnect. Wait for onAvailable()
-        // callback to ensure wifi is connected before the test and store the default network.
+        // Wait for onAvailable() callback to ensure network is available before the test
+        // and store the default network.
         mNetwork = mTestNetworkCallback.expectAvailableCallbackAndGetNetwork();
         // Check that the network is metered.
         mTestNetworkCallback.expectCapabilitiesCallbackEventually(mNetwork,
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
index 3041dfa..b61535b 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
@@ -20,6 +20,7 @@
 import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_ENABLED;
 import static android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_WHITELISTED;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
@@ -28,6 +29,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -40,25 +42,36 @@
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.wifi.WifiManager;
+import android.os.PersistableBundle;
 import android.os.Process;
+import android.telephony.CarrierConfigManager;
+import android.telephony.SubscriptionManager;
+import android.telephony.data.ApnSetting;
 import android.text.TextUtils;
 import android.util.Log;
-import android.util.Pair;
+
+import androidx.test.platform.app.InstrumentationRegistry;
 
 import com.android.compatibility.common.util.AppStandbyUtils;
 import com.android.compatibility.common.util.BatteryUtils;
+import com.android.compatibility.common.util.ShellIdentityUtils;
+import com.android.compatibility.common.util.ThrowingRunnable;
 
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
-import androidx.test.platform.app.InstrumentationRegistry;
-
 public class NetworkPolicyTestUtils {
 
+    // android.telephony.CarrierConfigManager.KEY_CARRIER_METERED_APN_TYPES_STRINGS
+    // TODO: Expose it as a @TestApi instead of copying the constant
+    private static final String KEY_CARRIER_METERED_APN_TYPES_STRINGS =
+            "carrier_metered_apn_types_strings";
+
     private static final int TIMEOUT_CHANGE_METEREDNESS_MS = 10_000;
 
     private static ConnectivityManager mCm;
     private static WifiManager mWm;
+    private static CarrierConfigManager mCarrierConfigManager;
 
     private static Boolean mBatterySaverSupported;
     private static Boolean mDataSaverSupported;
@@ -135,16 +148,40 @@
     }
 
     public static boolean canChangeActiveNetworkMeteredness() {
-        final Network activeNetwork = getConnectivityManager().getActiveNetwork();
-        final NetworkCapabilities networkCapabilities
-                = getConnectivityManager().getNetworkCapabilities(activeNetwork);
-        return networkCapabilities.hasTransport(TRANSPORT_WIFI);
+        final NetworkCapabilities networkCapabilities = getActiveNetworkCapabilities();
+        return networkCapabilities.hasTransport(TRANSPORT_WIFI)
+                || networkCapabilities.hasTransport(TRANSPORT_CELLULAR);
     }
 
-    public static Pair<String, Boolean> setupMeteredNetwork(boolean metered) throws Exception {
+    /**
+     * Updates the meteredness of the active network. Right now we can only change meteredness
+     * of either Wi-Fi or cellular networks, so if the active network is neither of these, this
+     * will throw an {@link IllegalStateException}.
+     *
+     * @return a {@link ThrowingRunnable} that reverts the change made by this method, or
+     *         {@code null} if {@link #isActiveNetworkMetered} returned true and nothing changed.
+     */
+    public static ThrowingRunnable setupActiveNetworkMeteredness(boolean metered) throws Exception {
         if (isActiveNetworkMetered(metered)) {
             return null;
         }
+        final NetworkCapabilities networkCapabilities = getActiveNetworkCapabilities();
+        if (networkCapabilities.hasTransport(TRANSPORT_WIFI)) {
+            final String ssid = getWifiSsid();
+            setWifiMeteredStatus(ssid, metered);
+            return () -> setWifiMeteredStatus(ssid, !metered);
+        } else if (networkCapabilities.hasTransport(TRANSPORT_CELLULAR)) {
+            final int subId = SubscriptionManager.getActiveDataSubscriptionId();
+            setCellularMeteredStatus(subId, metered);
+            return () -> setCellularMeteredStatus(subId, !metered);
+        } else {
+            // Right now, we don't have a way to change meteredness of networks other
+            // than Wi-Fi or Cellular, so just throw an exception.
+            throw new IllegalStateException("Can't change meteredness of current active network");
+        }
+    }
+
+    private static String getWifiSsid() {
         final boolean isLocationEnabled = isLocationEnabled();
         try {
             if (!isLocationEnabled) {
@@ -152,8 +189,7 @@
             }
             final String ssid = unquoteSSID(getWifiManager().getConnectionInfo().getSSID());
             assertNotEquals(WifiManager.UNKNOWN_SSID, ssid);
-            setWifiMeteredStatus(ssid, metered);
-            return Pair.create(ssid, !metered);
+            return ssid;
         } finally {
             // Reset the location enabled state
             if (!isLocationEnabled) {
@@ -162,11 +198,13 @@
         }
     }
 
-    public static void resetMeteredNetwork(String ssid, boolean metered) throws Exception {
-        setWifiMeteredStatus(ssid, metered);
+    private static NetworkCapabilities getActiveNetworkCapabilities() {
+        final Network activeNetwork = getConnectivityManager().getActiveNetwork();
+        assertNotNull("No active network available", activeNetwork);
+        return getConnectivityManager().getNetworkCapabilities(activeNetwork);
     }
 
-    public static void setWifiMeteredStatus(String ssid, boolean metered) throws Exception {
+    private static void setWifiMeteredStatus(String ssid, boolean metered) throws Exception {
         assertFalse("SSID should not be empty", TextUtils.isEmpty(ssid));
         final String cmd = "cmd netpolicy set metered-network " + ssid + " " + metered;
         executeShellCommand(cmd);
@@ -174,15 +212,24 @@
         assertActiveNetworkMetered(metered);
     }
 
-    public static void assertWifiMeteredStatus(String ssid, boolean expectedMeteredStatus) {
+    private static void assertWifiMeteredStatus(String ssid, boolean expectedMeteredStatus) {
         final String result = executeShellCommand("cmd netpolicy list wifi-networks");
         final String expectedLine = ssid + ";" + expectedMeteredStatus;
         assertTrue("Expected line: " + expectedLine + "; Actual result: " + result,
                 result.contains(expectedLine));
     }
 
+    private static void setCellularMeteredStatus(int subId, boolean metered) throws Exception {
+        final PersistableBundle bundle = new PersistableBundle();
+        bundle.putStringArray(KEY_CARRIER_METERED_APN_TYPES_STRINGS,
+                new String[] {ApnSetting.TYPE_MMS_STRING});
+        ShellIdentityUtils.invokeMethodWithShellPermissionsNoReturn(getCarrierConfigManager(),
+                (cm) -> cm.overrideConfig(subId, metered ? null : bundle));
+        assertActiveNetworkMetered(metered);
+    }
+
     // Copied from cts/tests/tests/net/src/android/net/cts/ConnectivityManagerTest.java
-    public static void assertActiveNetworkMetered(boolean expectedMeteredStatus) throws Exception {
+    private static void assertActiveNetworkMetered(boolean expectedMeteredStatus) throws Exception {
         final CountDownLatch latch = new CountDownLatch(1);
         final NetworkCallback networkCallback = new NetworkCallback() {
             @Override
@@ -197,12 +244,15 @@
         // with the current setting. Therefore, if the setting has already been changed,
         // this method will return right away, and if not it will wait for the setting to change.
         getConnectivityManager().registerDefaultNetworkCallback(networkCallback);
-        if (!latch.await(TIMEOUT_CHANGE_METEREDNESS_MS, TimeUnit.MILLISECONDS)) {
-            fail("Timed out waiting for active network metered status to change to "
-                    + expectedMeteredStatus + " ; network = "
-                    + getConnectivityManager().getActiveNetwork());
+        try {
+            if (!latch.await(TIMEOUT_CHANGE_METEREDNESS_MS, TimeUnit.MILLISECONDS)) {
+                fail("Timed out waiting for active network metered status to change to "
+                        + expectedMeteredStatus + "; network = "
+                        + getConnectivityManager().getActiveNetwork());
+            }
+        } finally {
+            getConnectivityManager().unregisterNetworkCallback(networkCallback);
         }
-        getConnectivityManager().unregisterNetworkCallback(networkCallback);
     }
 
     public static void setRestrictBackground(boolean enabled) {
@@ -274,6 +324,14 @@
         return mWm;
     }
 
+    public static CarrierConfigManager getCarrierConfigManager() {
+        if (mCarrierConfigManager == null) {
+            mCarrierConfigManager = (CarrierConfigManager) getContext().getSystemService(
+                    Context.CARRIER_CONFIG_SERVICE);
+        }
+        return mCarrierConfigManager;
+    }
+
     public static Context getContext() {
         return getInstrumentation().getContext();
     }
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 81a431c..4668ba3 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -16,6 +16,8 @@
 
 package com.android.cts.net.hostside;
 
+import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
 import static android.os.Process.INVALID_UID;
 import static android.system.OsConstants.AF_INET;
 import static android.system.OsConstants.AF_INET6;
@@ -25,6 +27,9 @@
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.POLLIN;
 import static android.system.OsConstants.SOCK_DGRAM;
+import static android.test.MoreAsserts.assertNotEqual;
+
+import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
 
 import android.annotation.Nullable;
 import android.app.DownloadManager;
@@ -45,9 +50,14 @@
 import android.net.NetworkRequest;
 import android.net.Proxy;
 import android.net.ProxyInfo;
+import android.net.TransportInfo;
 import android.net.Uri;
+import android.net.VpnManager;
 import android.net.VpnService;
+import android.net.VpnTransportInfo;
 import android.net.wifi.WifiManager;
+import android.os.Handler;
+import android.os.Looper;
 import android.os.ParcelFileDescriptor;
 import android.os.Process;
 import android.os.SystemProperties;
@@ -687,6 +697,20 @@
         setAndVerifyPrivateDns(initialMode);
     }
 
+    private static class NeverChangeNetworkCallback extends NetworkCallback {
+        private volatile Network mLastNetwork;  // set by onAvailable(), read by getLastNetwork()
+        @Override
+        public void onAvailable(Network n) {
+            assertNull("Callback got onAvailable more than once: " + mLastNetwork + ", " + n,
+                    mLastNetwork);
+            mLastNetwork = n;
+        }
+
+        public Network getLastNetwork() {
+            return mLastNetwork;
+        }
+    }
+
     public void testDefault() throws Exception {
         if (!supportedHardware()) return;
         // If adb TCP port opened, this test may running by adb over network.
@@ -702,6 +726,14 @@
                 getInstrumentation().getTargetContext(), MyVpnService.ACTION_ESTABLISHED);
         receiver.register();
 
+        // Expect the system default network not to change, even though the app's default
+        // network will switch to the VPN.
+        final NeverChangeNetworkCallback neverChangeCallback = new NeverChangeNetworkCallback();
+        final Network defaultNetwork = mCM.getActiveNetwork();
+        runWithShellPermissionIdentity(() ->
+                mCM.registerSystemDefaultNetworkCallback(neverChangeCallback,
+                        new Handler(Looper.getMainLooper())), NETWORK_SETTINGS);
+
         FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
 
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
@@ -719,6 +751,19 @@
 
         checkTrafficOnVpn();
 
+        expectVpnTransportInfo(mCM.getActiveNetwork());
+
+        // Check that the system default network callback has not seen any network changes, but
+        // the app's default network has. This needs to be done before testing private DNS because
+        // checkStrictModePrivateDns will set the private DNS server to a nonexistent name, which
+        // will cause validation to fail and could cause the default network to switch (e.g., from
+        // wifi to cellular).
+        assertEquals(defaultNetwork, neverChangeCallback.getLastNetwork());
+        assertNotEqual(defaultNetwork, mCM.getActiveNetwork());
+        runWithShellPermissionIdentity(
+                () ->  mCM.unregisterNetworkCallback(neverChangeCallback),
+                NETWORK_SETTINGS);
+
         checkStrictModePrivateDns();
 
         receiver.unregisterQuietly();
@@ -739,6 +784,8 @@
 
         checkTrafficOnVpn();
 
+        expectVpnTransportInfo(mCM.getActiveNetwork());
+
         checkStrictModePrivateDns();
     }
 
@@ -764,6 +811,10 @@
         assertSocketStillOpen(remoteFd, TEST_HOST);
 
         checkNoTrafficOnVpn();
+
+        final Network network = mCM.getActiveNetwork();
+        final NetworkCapabilities nc = mCM.getNetworkCapabilities(network);
+        assertFalse(nc.hasTransport(TRANSPORT_VPN));
     }
 
     public void testGetConnectionOwnerUidSecurity() throws Exception {
@@ -778,8 +829,11 @@
         InetSocketAddress rem = new InetSocketAddress(s.getInetAddress(), s.getPort());
         try {
             int uid = mCM.getConnectionOwnerUid(OsConstants.IPPROTO_TCP, loc, rem);
-            fail("Only an active VPN app may call this API.");
-        } catch (SecurityException expected) {
+            assertEquals("Only an active VPN app should see connection information",
+                    INVALID_UID, uid);
+        } catch (SecurityException acceptable) {
+            // R and below throw SecurityException if a non-active VPN calls this method.
+            // As long as we can't actually get socket information, either behavior is fine.
             return;
         }
     }
@@ -918,6 +972,8 @@
         // VPN with no underlying networks should be metered by default.
         assertTrue(isNetworkMetered(mNetwork));
         assertTrue(mCM.isActiveNetworkMetered());
+
+        expectVpnTransportInfo(mCM.getActiveNetwork());
     }
 
     public void testVpnMeterednessWithNullUnderlyingNetwork() throws Exception {
@@ -944,6 +1000,8 @@
         assertEquals(isNetworkMetered(underlyingNetwork), isNetworkMetered(mNetwork));
         // Meteredness based on VPN capabilities and CM#isActiveNetworkMetered should be in sync.
         assertEquals(isNetworkMetered(mNetwork), mCM.isActiveNetworkMetered());
+
+        expectVpnTransportInfo(mCM.getActiveNetwork());
     }
 
     public void testVpnMeterednessWithNonNullUnderlyingNetwork() throws Exception {
@@ -971,6 +1029,8 @@
         assertEquals(isNetworkMetered(underlyingNetwork), isNetworkMetered(mNetwork));
         // Meteredness based on VPN capabilities and CM#isActiveNetworkMetered should be in sync.
         assertEquals(isNetworkMetered(mNetwork), mCM.isActiveNetworkMetered());
+
+        expectVpnTransportInfo(mCM.getActiveNetwork());
     }
 
     public void testAlwaysMeteredVpnWithNullUnderlyingNetwork() throws Exception {
@@ -995,6 +1055,8 @@
         // VPN's meteredness does not depend on underlying network since it is always metered.
         assertTrue(isNetworkMetered(mNetwork));
         assertTrue(mCM.isActiveNetworkMetered());
+
+        expectVpnTransportInfo(mCM.getActiveNetwork());
     }
 
     public void testAlwaysMeteredVpnWithNonNullUnderlyingNetwork() throws Exception {
@@ -1020,6 +1082,8 @@
         // VPN's meteredness does not depend on underlying network since it is always metered.
         assertTrue(isNetworkMetered(mNetwork));
         assertTrue(mCM.isActiveNetworkMetered());
+
+        expectVpnTransportInfo(mCM.getActiveNetwork());
     }
 
     public void testB141603906() throws Exception {
@@ -1069,6 +1133,14 @@
         }
     }
 
+    private void expectVpnTransportInfo(Network network) {
+        final NetworkCapabilities vpnNc = mCM.getNetworkCapabilities(network);
+        assertTrue("Not a VPN: " + vpnNc, vpnNc != null && vpnNc.hasTransport(TRANSPORT_VPN));
+        final TransportInfo ti = vpnNc.getTransportInfo();
+        assertTrue("Expected VpnTransportInfo, got " + ti, ti instanceof VpnTransportInfo);
+        assertEquals(VpnManager.TYPE_VPN_SERVICE, ((VpnTransportInfo) ti).type);
+    }
+
     private void assertDefaultProxy(ProxyInfo expected) {
         assertEquals("Incorrect proxy config.", expected, mCM.getDefaultProxy());
         String expectedHost = expected == null ? null : expected.getHost();
diff --git a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
index 590e17e..1c9ff05 100644
--- a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
+++ b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
@@ -138,7 +138,7 @@
                     }
                 }
             };
-            mCm.registerNetworkCallback(makeWifiNetworkRequest(), mNetworkCallback);
+            mCm.registerNetworkCallback(makeNetworkRequest(), mNetworkCallback);
             try {
                 cb.asBinder().linkToDeath(() -> unregisterNetworkCallback(), 0);
             } catch (RemoteException e) {
@@ -156,9 +156,8 @@
         }
       };
 
-    private NetworkRequest makeWifiNetworkRequest() {
+    private NetworkRequest makeNetworkRequest() {  // no transport filter (was Wi-Fi only)
         return new NetworkRequest.Builder()
-                .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
                 .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                 .build();
     }
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index bd52bf9..3145d7e 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -93,6 +93,7 @@
 import android.net.wifi.WifiManager;
 import android.os.Binder;
 import android.os.Build;
+import android.os.Handler;
 import android.os.Looper;
 import android.os.MessageQueue;
 import android.os.SystemClock;
@@ -532,6 +533,13 @@
         final TestNetworkCallback defaultTrackingCallback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(defaultTrackingCallback);
 
+        final TestNetworkCallback systemDefaultTrackingCallback = new TestNetworkCallback();
+        runWithShellPermissionIdentity(() ->
+                mCm.registerSystemDefaultNetworkCallback(systemDefaultTrackingCallback,
+                        new Handler(Looper.getMainLooper())),
+                NETWORK_SETTINGS);  // assumed required by registerSystemDefaultNetworkCallback
+
+
         Network wifiNetwork = null;
 
         try {
@@ -551,6 +559,9 @@
         } finally {
             mCm.unregisterNetworkCallback(callback);
             mCm.unregisterNetworkCallback(defaultTrackingCallback);
+            runWithShellPermissionIdentity(
+                    () -> mCm.unregisterNetworkCallback(systemDefaultTrackingCallback),
+                    NETWORK_SETTINGS);
         }
     }
 
@@ -1547,6 +1558,7 @@
      * Verify background request can only be requested when acquiring
      * {@link android.Manifest.permission.NETWORK_SETTINGS}.
      */
+    @SkipPresubmit(reason = "Flaky: b/179554972; add to presubmit after fixing")
     @Test
     public void testRequestBackgroundNetwork() throws Exception {
         // Create a tun interface. Use the returned interface name as the specifier to create
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 69d90aa..41537a9 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -35,6 +35,7 @@
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN
 import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
 import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
@@ -47,6 +48,8 @@
 import android.net.SocketKeepalive
 import android.net.StringNetworkSpecifier
 import android.net.Uri
+import android.net.VpnManager
+import android.net.VpnTransportInfo
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAddKeepalivePacketFilter
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAutomaticReconnectDisabled
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnBandwidthUpdateRequested
@@ -321,6 +324,7 @@
             addCapability(NET_CAPABILITY_NOT_SUSPENDED)
             addCapability(NET_CAPABILITY_NOT_ROAMING)
             addCapability(NET_CAPABILITY_NOT_VPN)
+            addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
             if (null != name) {
                 setNetworkSpecifier(StringNetworkSpecifier(name))
             }
@@ -543,7 +547,7 @@
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.R)
-    fun testSetUnderlyingNetworks() {
+    fun testSetUnderlyingNetworksAndVpnSpecifier() {
         val request = NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_TEST)
                 .addTransportType(TRANSPORT_VPN)
@@ -557,6 +561,8 @@
             addTransportType(TRANSPORT_TEST)
             addTransportType(TRANSPORT_VPN)
             removeCapability(NET_CAPABILITY_NOT_VPN)
+            addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+            setTransportInfo(VpnTransportInfo(VpnManager.TYPE_VPN_SERVICE))
         }
         val defaultNetwork = mCM.activeNetwork
         assertNotNull(defaultNetwork)
@@ -571,6 +577,8 @@
         // Check that the default network's transport is propagated to the VPN.
         var vpnNc = mCM.getNetworkCapabilities(agent.network)
         assertNotNull(vpnNc)
+        assertEquals(VpnManager.TYPE_VPN_SERVICE,
+                (vpnNc.transportInfo as? VpnTransportInfo)?.type)
 
         val testAndVpn = intArrayOf(TRANSPORT_TEST, TRANSPORT_VPN)
         assertTrue(hasAllTransports(vpnNc, testAndVpn))
diff --git a/tests/cts/net/src/android/net/cts/NetworkRequestTest.java b/tests/cts/net/src/android/net/cts/NetworkRequestTest.java
index d118c8a..31dc64d 100644
--- a/tests/cts/net/src/android/net/cts/NetworkRequestTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkRequestTest.java
@@ -16,8 +16,13 @@
 
 package android.net.cts;
 
+import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_FOTA;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_MMS;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_SUPL;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED;
 import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
@@ -29,6 +34,7 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
+import android.annotation.NonNull;
 import android.net.MacAddress;
 import android.net.MatchAllNetworkSpecifier;
 import android.net.NetworkCapabilities;
@@ -43,6 +49,7 @@
 
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.modules.utils.build.SdkLevel;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 
@@ -152,29 +159,44 @@
                 .getRequestorPackageName());
     }
 
+    private void addNotVcnManagedCapability(@NonNull NetworkCapabilities nc) {
+        if (SdkLevel.isAtLeastS()) {
+            nc.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED);
+        }
+    }
+
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.Q)
     public void testCanBeSatisfiedBy() {
         final LocalNetworkSpecifier specifier1 = new LocalNetworkSpecifier(1234 /* id */);
         final LocalNetworkSpecifier specifier2 = new LocalNetworkSpecifier(5678 /* id */);
 
+        // Some requests add the NOT_VCN_MANAGED capability automatically. Add it to the
+        // capabilities below (on S and above) so that they can still satisfy those requests.
         final NetworkCapabilities capCellularMmsInternet = new NetworkCapabilities()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .addCapability(NET_CAPABILITY_MMS)
                 .addCapability(NET_CAPABILITY_INTERNET);
+        addNotVcnManagedCapability(capCellularMmsInternet);
         final NetworkCapabilities capCellularVpnMmsInternet =
                 new NetworkCapabilities(capCellularMmsInternet).addTransportType(TRANSPORT_VPN);
+        addNotVcnManagedCapability(capCellularVpnMmsInternet);
         final NetworkCapabilities capCellularMmsInternetSpecifier1 =
                 new NetworkCapabilities(capCellularMmsInternet).setNetworkSpecifier(specifier1);
+        addNotVcnManagedCapability(capCellularMmsInternetSpecifier1);
         final NetworkCapabilities capVpnInternetSpecifier1 = new NetworkCapabilities()
                 .addCapability(NET_CAPABILITY_INTERNET)
                 .addTransportType(TRANSPORT_VPN)
                 .setNetworkSpecifier(specifier1);
+        addNotVcnManagedCapability(capVpnInternetSpecifier1);
         final NetworkCapabilities capCellularMmsInternetMatchallspecifier =
                 new NetworkCapabilities(capCellularMmsInternet)
-                    .setNetworkSpecifier(new MatchAllNetworkSpecifier());
+                        .setNetworkSpecifier(new MatchAllNetworkSpecifier());
+        addNotVcnManagedCapability(capCellularMmsInternetMatchallspecifier);
         final NetworkCapabilities capCellularMmsInternetSpecifier2 =
-                new NetworkCapabilities(capCellularMmsInternet).setNetworkSpecifier(specifier2);
+                new NetworkCapabilities(capCellularMmsInternet)
+                        .setNetworkSpecifier(specifier2);
+        addNotVcnManagedCapability(capCellularMmsInternetSpecifier2);
 
         final NetworkRequest requestCellularInternetSpecifier1 = new NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
@@ -239,7 +261,8 @@
 
         final NetworkCapabilities capCellInternetBWSpecifier1Signal =
                 new NetworkCapabilities.Builder(capCellInternetBWSpecifier1)
-                    .setSignalStrength(-123).build();
+                        .setSignalStrength(-123).build();
+        addNotVcnManagedCapability(capCellInternetBWSpecifier1Signal);
         assertCorrectlySatisfies(true, requestCombination,
                 capCellInternetBWSpecifier1Signal);
 
@@ -273,4 +296,75 @@
         assertEquals(Process.INVALID_UID, new NetworkRequest.Builder()
                 .clearCapabilities().build().getRequestorUid());
     }
+
+    // TODO: 1. Refactor test cases with a helper method.
+    //       2. Test a capability that does not yet exist.
+    @Test @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testBypassingVcnForNonInternetRequest() {
+        // Make an empty request. Verify NOT_VCN_MANAGED is added.
+        final NetworkRequest emptyRequest = new NetworkRequest.Builder().build();
+        assertTrue(emptyRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a request explicitly adding NOT_VCN_MANAGED. Verify NOT_VCN_MANAGED is preserved.
+        final NetworkRequest mmsAddNotVcnRequest = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_MMS)
+                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+                .build();
+        assertTrue(mmsAddNotVcnRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Similar to above, but in the opposite order.
+        final NetworkRequest mmsAddNotVcnRequest2 = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+                .addCapability(NET_CAPABILITY_MMS)
+                .build();
+        assertTrue(mmsAddNotVcnRequest2.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a request explicitly removing NOT_VCN_MANAGED. Verify NOT_VCN_MANAGED is removed.
+        final NetworkRequest removeNotVcnRequest = new NetworkRequest.Builder()
+                .removeCapability(NET_CAPABILITY_NOT_VCN_MANAGED).build();
+        assertFalse(removeNotVcnRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a request adding a capability within the VCN supported capabilities.
+        // Verify NOT_VCN_MANAGED is added.
+        final NetworkRequest notRoamRequest = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_NOT_ROAMING).build();
+        assertTrue(notRoamRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make an internet request. Verify NOT_VCN_MANAGED is added.
+        final NetworkRequest internetRequest = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_INTERNET).build();
+        assertTrue(internetRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make an internet request that explicitly removes NOT_VCN_MANAGED.
+        // Verify NOT_VCN_MANAGED is removed.
+        final NetworkRequest internetRemoveNotVcnRequest = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_INTERNET)
+                .removeCapability(NET_CAPABILITY_NOT_VCN_MANAGED).build();
+        assertFalse(internetRemoveNotVcnRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a plain MMS request. Verify NOT_VCN_MANAGED is not added, so it can bypass VCN.
+        final NetworkRequest mmsRequest =
+                new NetworkRequest.Builder().addCapability(NET_CAPABILITY_MMS).build();
+        assertFalse(mmsRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a SUPL request along with internet. Verify NOT_VCN_MANAGED is not added since
+        // SUPL is not in the supported list.
+        final NetworkRequest suplWithInternetRequest = new NetworkRequest.Builder()
+                        .addCapability(NET_CAPABILITY_SUPL)
+                        .addCapability(NET_CAPABILITY_INTERNET).build();
+        assertFalse(suplWithInternetRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a FOTA request that explicitly adds the NOT_VCN_MANAGED capability. Verify
+        // NOT_VCN_MANAGED is preserved.
+        final NetworkRequest fotaRequest = new NetworkRequest.Builder()
+                        .addCapability(NET_CAPABILITY_FOTA)
+                        .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED).build();
+        assertTrue(fotaRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+
+        // Make a DUN request, which is in VCN_SUPPORTED_CAPABILITIES.
+        // Verify NOT_VCN_MANAGED is added.
+        final NetworkRequest dunRequest = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_DUN).build();
+        assertTrue(dunRequest.hasCapability(NET_CAPABILITY_NOT_VCN_MANAGED));
+    }
 }