Merge "Ignore testFactoryReset in instant app mode" into sc-dev
diff --git a/tests/cts/net/Android.bp b/tests/cts/net/Android.bp
index 25596a9..85942b0 100644
--- a/tests/cts/net/Android.bp
+++ b/tests/cts/net/Android.bp
@@ -41,6 +41,7 @@
     srcs: [
         "src/**/*.java",
         "src/**/*.kt",
+        ":ike-aes-xcbc",
     ],
     jarjar_rules: "jarjar-rules-shared.txt",
     static_libs: [
diff --git a/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java b/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
index c6d8d65..6e9f0cd 100644
--- a/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
+++ b/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
@@ -39,13 +39,11 @@
 import android.net.ConnectivityManager;
 import android.net.Ikev2VpnProfile;
 import android.net.IpSecAlgorithm;
-import android.net.LinkAddress;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.ProxyInfo;
 import android.net.TestNetworkInterface;
-import android.net.TestNetworkManager;
 import android.net.VpnManager;
 import android.net.cts.util.CtsNetUtils;
 import android.os.Build;
@@ -439,41 +437,37 @@
         assertEquals(vpnNetwork, cb.lastLostNetwork);
     }
 
-    private void doTestStartStopVpnProfile(boolean testIpv6) throws Exception {
-        // Non-final; these variables ensure we clean up properly after our test if we have
-        // allocated test network resources
-        final TestNetworkManager tnm = sContext.getSystemService(TestNetworkManager.class);
-        TestNetworkInterface testIface = null;
-        TestNetworkCallback tunNetworkCallback = null;
+    private class VerifyStartStopVpnProfileTest implements TestNetworkRunnable.Test {
+        private final boolean mTestIpv6Only;
 
-        try {
-            // Build underlying test network
-            testIface = tnm.createTunInterface(
-                    new LinkAddress[] {
-                            new LinkAddress(LOCAL_OUTER_4, IP4_PREFIX_LEN),
-                            new LinkAddress(LOCAL_OUTER_6, IP6_PREFIX_LEN)});
+        /**
+         * Constructs a start/stop VPN profile test.
+         *
+         * @param testIpv6Only if true, builds an IPv6-only test; otherwise builds an IPv4-only test
+         */
+        VerifyStartStopVpnProfileTest(boolean testIpv6Only) {
+            mTestIpv6Only = testIpv6Only;
+        }
 
-            // Hold on to this callback to ensure network does not get reaped.
-            tunNetworkCallback = mCtsNetUtils.setupAndGetTestNetwork(
-                    testIface.getInterfaceName());
+        @Override
+        public void runTest(TestNetworkInterface testIface, TestNetworkCallback tunNetworkCallback)
+                throws Exception {
             final IkeTunUtils tunUtils = new IkeTunUtils(testIface.getFileDescriptor());
 
-            checkStartStopVpnProfileBuildsNetworks(tunUtils, testIpv6);
-        } finally {
-            // Make sure to stop the VPN profile. This is safe to call multiple times.
+            checkStartStopVpnProfileBuildsNetworks(tunUtils, mTestIpv6Only);
+        }
+
+        @Override
+        public void cleanupTest() {
             sVpnMgr.stopProvisionedVpnProfile();
+        }
 
-            if (testIface != null) {
-                testIface.getFileDescriptor().close();
-            }
-
-            if (tunNetworkCallback != null) {
-                sCM.unregisterNetworkCallback(tunNetworkCallback);
-            }
-
-            final Network testNetwork = tunNetworkCallback.currentNetwork;
-            if (testNetwork != null) {
-                tnm.teardownTestNetwork(testNetwork);
+        @Override
+        public InetAddress[] getTestNetworkAddresses() {
+            if (mTestIpv6Only) {
+                return new InetAddress[] {LOCAL_OUTER_6};
+            } else {
+                return new InetAddress[] {LOCAL_OUTER_4};
             }
         }
     }
@@ -483,9 +477,8 @@
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
 
         // Requires shell permission to update appops.
-        runWithShellPermissionIdentity(() -> {
-            doTestStartStopVpnProfile(false);
-        });
+        runWithShellPermissionIdentity(
+                new TestNetworkRunnable(new VerifyStartStopVpnProfileTest(false)));
     }
 
     @Test
@@ -493,9 +486,8 @@
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
 
         // Requires shell permission to update appops.
-        runWithShellPermissionIdentity(() -> {
-            doTestStartStopVpnProfile(true);
-        });
+        runWithShellPermissionIdentity(
+                new TestNetworkRunnable(new VerifyStartStopVpnProfileTest(true)));
     }
 
     private static class CertificateAndKey {
diff --git a/tests/cts/net/src/android/net/cts/IpSecAlgorithmImplTest.java b/tests/cts/net/src/android/net/cts/IpSecAlgorithmImplTest.java
new file mode 100644
index 0000000..2f29273
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/IpSecAlgorithmImplTest.java
@@ -0,0 +1,306 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.cts;
+
+import static android.net.IpSecAlgorithm.AUTH_AES_CMAC;
+import static android.net.IpSecAlgorithm.AUTH_AES_XCBC;
+import static android.net.IpSecAlgorithm.AUTH_CRYPT_CHACHA20_POLY1305;
+import static android.net.IpSecAlgorithm.CRYPT_AES_CTR;
+import static android.net.cts.PacketUtils.AES_CMAC;
+import static android.net.cts.PacketUtils.AES_CMAC_ICV_LEN;
+import static android.net.cts.PacketUtils.AES_CMAC_KEY_LEN;
+import static android.net.cts.PacketUtils.AES_CTR;
+import static android.net.cts.PacketUtils.AES_CTR_BLK_SIZE;
+import static android.net.cts.PacketUtils.AES_CTR_IV_LEN;
+import static android.net.cts.PacketUtils.AES_CTR_KEY_LEN_20;
+import static android.net.cts.PacketUtils.AES_CTR_KEY_LEN_28;
+import static android.net.cts.PacketUtils.AES_CTR_KEY_LEN_36;
+import static android.net.cts.PacketUtils.AES_CTR_SALT_LEN;
+import static android.net.cts.PacketUtils.AES_XCBC;
+import static android.net.cts.PacketUtils.AES_XCBC_ICV_LEN;
+import static android.net.cts.PacketUtils.AES_XCBC_KEY_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_BLK_SIZE;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_ICV_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_IV_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_KEY_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_SALT_LEN;
+import static android.net.cts.PacketUtils.ESP_HDRLEN;
+import static android.net.cts.PacketUtils.IP6_HDRLEN;
+import static android.net.cts.PacketUtils.getIpHeader;
+import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
+
+import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assume.assumeTrue;
+
+import android.net.IpSecAlgorithm;
+import android.net.IpSecManager;
+import android.net.IpSecTransform;
+import android.net.Network;
+import android.net.TestNetworkInterface;
+import android.net.cts.PacketUtils.BytePayload;
+import android.net.cts.PacketUtils.EspAeadCipher;
+import android.net.cts.PacketUtils.EspAuth;
+import android.net.cts.PacketUtils.EspAuthNull;
+import android.net.cts.PacketUtils.EspCipher;
+import android.net.cts.PacketUtils.EspCipherNull;
+import android.net.cts.PacketUtils.EspCryptCipher;
+import android.net.cts.PacketUtils.EspHeader;
+import android.net.cts.PacketUtils.IpHeader;
+import android.net.cts.PacketUtils.UdpHeader;
+import android.platform.test.annotations.AppModeFull;
+
+import androidx.test.InstrumentationRegistry;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.net.DatagramPacket;
+import java.net.DatagramSocket;
+import java.net.InetAddress;
+import java.util.Arrays;
+
+@RunWith(AndroidJUnit4.class)
+@AppModeFull(reason = "Socket cannot bind in instant app mode")
+public class IpSecAlgorithmImplTest extends IpSecBaseTest {
+    private static final InetAddress LOCAL_ADDRESS =
+            InetAddress.parseNumericAddress("2001:db8:1::1");
+    private static final InetAddress REMOTE_ADDRESS =
+            InetAddress.parseNumericAddress("2001:db8:1::2");
+
+    private static final int REMOTE_PORT = 12345;
+    private static final IpSecManager sIpSecManager =
+            InstrumentationRegistry.getContext().getSystemService(IpSecManager.class);
+    /** Checks that kernel-generated ESP packets match a user-space ESP implementation. */
+    private static class CheckCryptoImplTest implements TestNetworkRunnable.Test {
+        private final IpSecAlgorithm mIpsecEncryptAlgo;
+        private final IpSecAlgorithm mIpsecAuthAlgo;
+        private final IpSecAlgorithm mIpsecAeadAlgo;
+        private final EspCipher mEspCipher;
+        private final EspAuth mEspAuth;
+
+        CheckCryptoImplTest(
+                IpSecAlgorithm ipsecEncryptAlgo,
+                IpSecAlgorithm ipsecAuthAlgo,
+                IpSecAlgorithm ipsecAeadAlgo,
+                EspCipher espCipher,
+                EspAuth espAuth) {
+            mIpsecEncryptAlgo = ipsecEncryptAlgo;
+            mIpsecAuthAlgo = ipsecAuthAlgo;
+            mIpsecAeadAlgo = ipsecAeadAlgo;
+            mEspCipher = espCipher;
+            mEspAuth = espAuth;
+        }
+        /** Builds the expected transport-mode ESP payload for a UDP packet carrying TEST_DATA. */
+        private static byte[] buildTransportModeEspPayload(
+                int srcPort, int dstPort, int spi, EspCipher espCipher, EspAuth espAuth)
+                throws Exception {
+            final UdpHeader udpPayload =
+                    new UdpHeader(srcPort, dstPort, new BytePayload(TEST_DATA));
+            final IpHeader preEspIpHeader =
+                    getIpHeader(
+                            udpPayload.getProtocolId(), LOCAL_ADDRESS, REMOTE_ADDRESS, udpPayload);
+
+            final PacketUtils.EspHeader espPayload =
+                    new EspHeader(
+                            udpPayload.getProtocolId(),
+                            spi,
+                            1 /* sequence number */,
+                            udpPayload.getPacketBytes(preEspIpHeader),
+                            espCipher,
+                            espAuth);
+            return espPayload.getPacketBytes(preEspIpHeader);
+        }
+
+        @Override
+        public void runTest(TestNetworkInterface testIface, TestNetworkCallback tunNetworkCallback)
+                throws Exception {
+            final TunUtils tunUtils = new TunUtils(testIface.getFileDescriptor());
+            tunNetworkCallback.waitForAvailable();
+            final Network testNetwork = tunNetworkCallback.currentNetwork;
+
+            final IpSecTransform.Builder transformBuilder =
+                    new IpSecTransform.Builder(InstrumentationRegistry.getContext());
+            if (mIpsecAeadAlgo != null) {
+                transformBuilder.setAuthenticatedEncryption(mIpsecAeadAlgo);
+            } else {
+                if (mIpsecEncryptAlgo != null) {
+                    transformBuilder.setEncryption(mIpsecEncryptAlgo);
+                }
+                if (mIpsecAuthAlgo != null) {
+                    transformBuilder.setAuthentication(mIpsecAuthAlgo);
+                }
+            }
+
+            try (IpSecManager.SecurityParameterIndex outSpi =
+                            sIpSecManager.allocateSecurityParameterIndex(REMOTE_ADDRESS);
+                    IpSecManager.SecurityParameterIndex inSpi =
+                            sIpSecManager.allocateSecurityParameterIndex(LOCAL_ADDRESS);
+                    IpSecTransform outTransform =
+                            transformBuilder.buildTransportModeTransform(LOCAL_ADDRESS, outSpi);
+                    IpSecTransform inTransform =
+                            transformBuilder.buildTransportModeTransform(REMOTE_ADDRESS, inSpi);
+                    // Bind localSocket to a random available port.
+                    DatagramSocket localSocket = new DatagramSocket(0)) {
+                sIpSecManager.applyTransportModeTransform(
+                        localSocket, IpSecManager.DIRECTION_IN, inTransform);
+                sIpSecManager.applyTransportModeTransform(
+                        localSocket, IpSecManager.DIRECTION_OUT, outTransform);
+
+                // Send through the outbound transform and capture the kernel's ESP packet
+                final DatagramPacket outPacket =
+                        new DatagramPacket(
+                                TEST_DATA, 0, TEST_DATA.length, REMOTE_ADDRESS, REMOTE_PORT);
+                testNetwork.bindSocket(localSocket);
+                localSocket.send(outPacket);
+                final byte[] outEspPacket =
+                        tunUtils.awaitEspPacket(outSpi.getSpi(), false /* useEncap */);
+
+                // The outbound packet has been captured; remove the transforms from the socket.
+                sIpSecManager.removeTransportModeTransforms(localSocket);
+
+                // Strip the IPv6 header to get the kernel-generated ESP payload
+                final byte[] outEspPayload = new byte[outEspPacket.length - IP6_HDRLEN];
+                System.arraycopy(outEspPacket, IP6_HDRLEN, outEspPayload, 0, outEspPayload.length);
+
+                // The IV immediately follows the ESP header (SPI + sequence number)
+                final byte[] iv =
+                        Arrays.copyOfRange(
+                                outEspPayload, ESP_HDRLEN, ESP_HDRLEN + mEspCipher.ivLen);
+
+                // Build ESP payload using the kernel-generated IV and the user space crypto
+                // implementations
+                mEspCipher.updateIv(iv);
+                final byte[] expectedEspPayload =
+                        buildTransportModeEspPayload(
+                                localSocket.getLocalPort(),
+                                REMOTE_PORT,
+                                outSpi.getSpi(),
+                                mEspCipher,
+                                mEspAuth);
+
+                // Compare user-space-generated and kernel-generated ESP payload
+                assertArrayEquals(expectedEspPayload, outEspPayload);
+            }
+        }
+
+        @Override
+        public void cleanupTest() {
+            // Nothing to clean up: runTest closes its SPIs, transforms and socket itself.
+        }
+
+        @Override
+        public InetAddress[] getTestNetworkAddresses() {
+            return new InetAddress[] {LOCAL_ADDRESS};
+        }
+    }
+    /** Checks AES-CTR encryption, without authentication, using a key of {@code keyLen} bytes. */
+    private void checkAesCtr(int keyLen) throws Exception {
+        final byte[] cryptKey = getKeyBytes(keyLen);
+
+        final IpSecAlgorithm ipsecEncryptAlgo =
+                new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CTR, cryptKey);
+        final EspCipher espCipher =
+                new EspCryptCipher(
+                        AES_CTR, AES_CTR_BLK_SIZE, cryptKey, AES_CTR_IV_LEN, AES_CTR_SALT_LEN);
+
+        runWithShellPermissionIdentity(new TestNetworkRunnable(new CheckCryptoImplTest(
+                ipsecEncryptAlgo, null /* ipsecAuthAlgo */, null /* ipsecAeadAlgo */,
+                espCipher, EspAuthNull.getInstance())));
+    }
+
+    @Test
+    public void testAesCtr160() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(CRYPT_AES_CTR));
+
+        checkAesCtr(AES_CTR_KEY_LEN_20);
+    }
+
+    @Test
+    public void testAesCtr224() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(CRYPT_AES_CTR));
+
+        checkAesCtr(AES_CTR_KEY_LEN_28);
+    }
+
+    @Test
+    public void testAesCtr288() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(CRYPT_AES_CTR));
+
+        checkAesCtr(AES_CTR_KEY_LEN_36);
+    }
+
+    @Test
+    public void testAesXcbc() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(AUTH_AES_XCBC));
+
+        final byte[] authKey = getKeyBytes(AES_XCBC_KEY_LEN);
+        final IpSecAlgorithm ipsecAuthAlgo =
+                new IpSecAlgorithm(IpSecAlgorithm.AUTH_AES_XCBC, authKey, AES_XCBC_ICV_LEN * 8);
+        final EspAuth espAuth = new EspAuth(AES_XCBC, authKey, AES_XCBC_ICV_LEN);
+
+        runWithShellPermissionIdentity(new TestNetworkRunnable(new CheckCryptoImplTest(
+                null /* ipsecEncryptAlgo */, ipsecAuthAlgo, null /* ipsecAeadAlgo */,
+                EspCipherNull.getInstance(), espAuth)));
+    }
+
+    @Test
+    public void testAesCmac() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(AUTH_AES_CMAC));
+
+        final byte[] authKey = getKeyBytes(AES_CMAC_KEY_LEN);
+        final IpSecAlgorithm ipsecAuthAlgo =
+                new IpSecAlgorithm(IpSecAlgorithm.AUTH_AES_CMAC, authKey, AES_CMAC_ICV_LEN * 8);
+        final EspAuth espAuth = new EspAuth(AES_CMAC, authKey, AES_CMAC_ICV_LEN);
+
+        runWithShellPermissionIdentity(new TestNetworkRunnable(new CheckCryptoImplTest(
+                null /* ipsecEncryptAlgo */, ipsecAuthAlgo, null /* ipsecAeadAlgo */,
+                EspCipherNull.getInstance(), espAuth)));
+    }
+
+    @Test
+    public void testChaCha20Poly1305() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(AUTH_CRYPT_CHACHA20_POLY1305));
+
+        final byte[] cryptKey = getKeyBytes(CHACHA20_POLY1305_KEY_LEN);
+        final IpSecAlgorithm ipsecAeadAlgo =
+                new IpSecAlgorithm(
+                        IpSecAlgorithm.AUTH_CRYPT_CHACHA20_POLY1305,
+                        cryptKey,
+                        CHACHA20_POLY1305_ICV_LEN * 8);
+        final EspAeadCipher espAead =
+                new EspAeadCipher(
+                        CHACHA20_POLY1305,
+                        CHACHA20_POLY1305_BLK_SIZE,
+                        cryptKey,
+                        CHACHA20_POLY1305_IV_LEN,
+                        CHACHA20_POLY1305_ICV_LEN,
+                        CHACHA20_POLY1305_SALT_LEN);
+
+        runWithShellPermissionIdentity(
+                new TestNetworkRunnable(
+                        new CheckCryptoImplTest(
+                                null /* ipsecEncryptAlgo */,
+                                null /* ipsecAuthAlgo */,
+                                ipsecAeadAlgo,
+                                espAead,
+                                EspAuthNull.getInstance())));
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
index e7e1d67..5c95aa3 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
@@ -33,7 +33,7 @@
 import static android.net.cts.PacketUtils.AES_CMAC_KEY_LEN;
 import static android.net.cts.PacketUtils.AES_CTR_BLK_SIZE;
 import static android.net.cts.PacketUtils.AES_CTR_IV_LEN;
-import static android.net.cts.PacketUtils.AES_CTR_KEY_LEN;
+import static android.net.cts.PacketUtils.AES_CTR_KEY_LEN_20;
 import static android.net.cts.PacketUtils.AES_GCM_BLK_SIZE;
 import static android.net.cts.PacketUtils.AES_GCM_IV_LEN;
 import static android.net.cts.PacketUtils.AES_XCBC_ICV_LEN;
@@ -74,7 +74,6 @@
 
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
-import com.android.testutils.SkipPresubmit;
 
 import org.junit.Rule;
 import org.junit.Test;
@@ -771,7 +770,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcHmacMd5Tcp6() throws Exception {
         IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
         IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_MD5, getKey(128), 96);
@@ -804,7 +802,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcHmacSha1Tcp6() throws Exception {
         IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
         IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_SHA1, getKey(160), 96);
@@ -837,7 +834,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcHmacSha256Tcp6() throws Exception {
         IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
         IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_SHA256, getKey(256), 128);
@@ -870,7 +866,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcHmacSha384Tcp6() throws Exception {
         IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
         IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_SHA384, getKey(384), 192);
@@ -903,7 +898,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcHmacSha512Tcp6() throws Exception {
         IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
         IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_SHA512, getKey(512), 256);
@@ -928,7 +922,7 @@
     }
 
     private static IpSecAlgorithm buildCryptAesCtr() throws Exception {
-        return new IpSecAlgorithm(CRYPT_AES_CTR, getKeyBytes(AES_CTR_KEY_LEN));
+        return new IpSecAlgorithm(CRYPT_AES_CTR, getKeyBytes(AES_CTR_KEY_LEN_20)); // AES key + salt
     }
 
     private static IpSecAlgorithm buildAuthHmacSha512() throws Exception {
@@ -947,7 +941,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCtrHmacSha512Tcp6() throws Exception {
         assumeTrue(hasIpSecAlgorithm(CRYPT_AES_CTR));
 
@@ -1002,7 +995,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcAesXCbcTcp6() throws Exception {
         assumeTrue(hasIpSecAlgorithm(AUTH_AES_XCBC));
 
@@ -1043,7 +1035,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesCbcAesCmacTcp6() throws Exception {
         assumeTrue(hasIpSecAlgorithm(AUTH_AES_CMAC));
 
@@ -1082,7 +1073,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesGcm64Tcp6() throws Exception {
         IpSecAlgorithm authCrypt =
                 new IpSecAlgorithm(IpSecAlgorithm.AUTH_CRYPT_AES_GCM, AEAD_KEY, 64);
@@ -1115,7 +1105,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesGcm96Tcp6() throws Exception {
         IpSecAlgorithm authCrypt =
                 new IpSecAlgorithm(IpSecAlgorithm.AUTH_CRYPT_AES_GCM, AEAD_KEY, 96);
@@ -1148,7 +1137,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAesGcm128Tcp6() throws Exception {
         IpSecAlgorithm authCrypt =
                 new IpSecAlgorithm(IpSecAlgorithm.AUTH_CRYPT_AES_GCM, AEAD_KEY, 128);
@@ -1187,7 +1175,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testChaCha20Poly1305Tcp6() throws Exception {
         assumeTrue(hasIpSecAlgorithm(AUTH_CRYPT_CHACHA20_POLY1305));
 
@@ -1463,7 +1450,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testCryptTcp6() throws Exception {
         IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
         checkTransform(IPPROTO_TCP, IPV6_LOOPBACK, crypt, null, null, false, 1, false);
@@ -1471,7 +1457,6 @@
     }
 
     @Test
-    @SkipPresubmit(reason = "b/186608065 - kernel 5.10 regression in TrafficStats with ipsec")
     public void testAuthTcp6() throws Exception {
         IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_SHA256, getKey(256), 128);
         checkTransform(IPPROTO_TCP, IPV6_LOOPBACK, null, auth, null, false, 1, false);
diff --git a/tests/cts/net/src/android/net/cts/PacketUtils.java b/tests/cts/net/src/android/net/cts/PacketUtils.java
index 7e622f6..4d924d1 100644
--- a/tests/cts/net/src/android/net/cts/PacketUtils.java
+++ b/tests/cts/net/src/android/net/cts/PacketUtils.java
@@ -19,6 +19,10 @@
 import static android.system.OsConstants.IPPROTO_IPV6;
 import static android.system.OsConstants.IPPROTO_UDP;
 
+import android.util.ArraySet;
+
+import com.android.internal.net.ipsec.ike.crypto.AesXCbcImpl;
+
 import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
@@ -27,6 +31,7 @@
 import java.security.GeneralSecurityException;
 import java.security.SecureRandom;
 import java.util.Arrays;
+import java.util.Set;
 
 import javax.crypto.Cipher;
 import javax.crypto.Mac;
@@ -43,7 +48,9 @@
     static final int UDP_HDRLEN = 8;
     static final int TCP_HDRLEN = 20;
     static final int TCP_HDRLEN_WITH_TIMESTAMP_OPT = TCP_HDRLEN + 12;
+    static final int ESP_HDRLEN = 8;
     static final int ESP_BLK_SIZE = 4; // ESP has to be 4-byte aligned
+    static final int ESP_TRAILER_LEN = 2;
 
     // Not defined in OsConstants
     static final int IPPROTO_IPV4 = 4;
@@ -52,19 +59,25 @@
     // Encryption parameters
     static final int AES_CBC_IV_LEN = 16;
     static final int AES_CBC_BLK_SIZE = 16;
+    static final int AES_CTR_SALT_LEN = 4;
 
-    static final int AES_CTR_KEY_LEN = 20;
+    static final int AES_CTR_KEY_LEN_20 = 20;
+    static final int AES_CTR_KEY_LEN_28 = 28;
+    static final int AES_CTR_KEY_LEN_36 = 36;
     static final int AES_CTR_BLK_SIZE = ESP_BLK_SIZE;
     static final int AES_CTR_IV_LEN = 8;
 
     // AEAD parameters
     static final int AES_GCM_IV_LEN = 8;
     static final int AES_GCM_BLK_SIZE = 4;
+    static final int CHACHA20_POLY1305_KEY_LEN = 36;
     static final int CHACHA20_POLY1305_BLK_SIZE = ESP_BLK_SIZE;
     static final int CHACHA20_POLY1305_IV_LEN = 8;
+    static final int CHACHA20_POLY1305_SALT_LEN = 4;
     static final int CHACHA20_POLY1305_ICV_LEN = 16;
 
     // Authentication parameters
+    static final int HMAC_SHA256_ICV_LEN = 16;
     static final int HMAC_SHA512_KEY_LEN = 64;
     static final int HMAC_SHA512_ICV_LEN = 32;
     static final int AES_XCBC_KEY_LEN = 16;
@@ -72,10 +85,25 @@
     static final int AES_CMAC_KEY_LEN = 16;
     static final int AES_CMAC_ICV_LEN = 12;
 
+    // Per RFC 3686, the 32-bit block counter field starts at one.
+    static final byte[] AES_CTR_INITIAL_COUNTER = new byte[] {0x00, 0x00, 0x00, 0x01};
+
     // Encryption algorithms
     static final String AES = "AES";
     static final String AES_CBC = "AES/CBC/NoPadding";
+    static final String AES_CTR = "AES/CTR/NoPadding";
+
+    // AEAD algorithms
+    static final String CHACHA20_POLY1305 = "ChaCha20/Poly1305/NoPadding";
+
+    // Authentication algorithms
+    static final String HMAC_MD5 = "HmacMD5";
+    static final String HMAC_SHA1 = "HmacSHA1";
     static final String HMAC_SHA_256 = "HmacSHA256";
+    static final String HMAC_SHA_384 = "HmacSHA384";
+    static final String HMAC_SHA_512 = "HmacSHA512";
+    static final String AES_CMAC = "AESCMAC";
+    static final String AES_XCBC = "AesXCbc";
 
     public interface Payload {
         byte[] getPacketBytes(IpHeader header) throws Exception;
@@ -328,8 +356,9 @@
         public final int nextHeader;
         public final int spi;
         public final int seqNum;
-        public final byte[] key;
         public final byte[] payload;
+        public final EspCipher cipher;
+        public final EspAuth auth;
 
         /**
          * Generic constructor for ESP headers.
@@ -340,11 +369,48 @@
          * calculated using the pre-encryption IP header
          */
         public EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload) {
+            this(nextHeader, spi, seqNum, payload, getDefaultCipher(key), getDefaultAuth(key));
+        }
+
+        /**
+         * Generic constructor for ESP headers that allows configuring encryption and authentication
+         * algorithms.
+         *
+         * <p>For Tunnel mode, payload will be a full IP header + attached payloads
+         *
+         * <p>For Transport mode, payload will be only the attached payloads, but with the checksum
+         * calculated using the pre-encryption IP header
+         */
+        public EspHeader(
+                int nextHeader,
+                int spi,
+                int seqNum,
+                byte[] payload,
+                EspCipher cipher,
+                EspAuth auth) {
             this.nextHeader = nextHeader;
             this.spi = spi;
             this.seqNum = seqNum;
-            this.key = key;
             this.payload = payload;
+            this.cipher = cipher;
+            this.auth = auth;
+            // Reject an empty configuration, and AEAD combined with a separate auth algorithm.
+            if (cipher instanceof EspCipherNull && auth instanceof EspAuthNull) {
+                throw new IllegalArgumentException("No algorithm is provided");
+            }
+
+            if (cipher instanceof EspAeadCipher && !(auth instanceof EspAuthNull)) {
+                throw new IllegalArgumentException(
+                        "AEAD is provided with an authentication" + " algorithm.");
+            }
+        }
+        // Key-only constructor defaults: AES-CBC + HMAC-SHA256 (16-byte ICV), sharing one key.
+        private static EspCipher getDefaultCipher(byte[] key) {
+            return new EspCryptCipher(AES_CBC, AES_CBC_BLK_SIZE, key, AES_CBC_IV_LEN);
+        }
+
+        private static EspAuth getDefaultAuth(byte[] key) {
+            return new EspAuth(HMAC_SHA_256, key, HMAC_SHA256_ICV_LEN);
         }
 
         public int getProtocolId() {
@@ -352,9 +418,10 @@
         }
 
         public short length() {
-            // ALWAYS uses AES-CBC, HMAC-SHA256 (128b trunc len)
-            return (short)
-                    calculateEspPacketSize(payload.length, AES_CBC_IV_LEN, AES_CBC_BLK_SIZE, 128);
+            final int icvLen =
+                    cipher instanceof EspAeadCipher ? ((EspAeadCipher) cipher).icvLen : auth.icvLen;
+            return calculateEspPacketSize(
+                    payload.length, cipher.ivLen, cipher.blockSize, icvLen * 8); // ICV in bits
         }
 
         public byte[] getPacketBytes(IpHeader header) throws Exception {
@@ -368,58 +435,12 @@
             ByteBuffer espPayloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
             espPayloadBuffer.putInt(spi);
             espPayloadBuffer.putInt(seqNum);
-            espPayloadBuffer.put(getCiphertext(key));
 
-            espPayloadBuffer.put(getIcv(getByteArrayFromBuffer(espPayloadBuffer)), 0, 16);
+            espPayloadBuffer.put(cipher.getCipherText(nextHeader, payload, spi, seqNum));
+            espPayloadBuffer.put(auth.getIcv(getByteArrayFromBuffer(espPayloadBuffer)));
+
             resultBuffer.put(getByteArrayFromBuffer(espPayloadBuffer));
         }
-
-        private byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
-            Mac sha256HMAC = Mac.getInstance(HMAC_SHA_256);
-            SecretKeySpec authKey = new SecretKeySpec(key, HMAC_SHA_256);
-            sha256HMAC.init(authKey);
-
-            return sha256HMAC.doFinal(authenticatedSection);
-        }
-
-        /**
-         * Encrypts and builds ciphertext block. Includes the IV, Padding and Next-Header blocks
-         *
-         * <p>The ciphertext does NOT include the SPI/Sequence numbers, or the ICV.
-         */
-        private byte[] getCiphertext(byte[] key) throws GeneralSecurityException {
-            int paddedLen = calculateEspEncryptedLength(payload.length, AES_CBC_BLK_SIZE);
-            ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen);
-            paddedPayload.put(payload);
-
-            // Add padding - consecutive integers from 0x01
-            int pad = 1;
-            while (paddedPayload.position() < paddedPayload.limit()) {
-                paddedPayload.put((byte) pad++);
-            }
-
-            paddedPayload.position(paddedPayload.limit() - 2);
-            paddedPayload.put((byte) (paddedLen - 2 - payload.length)); // Pad length
-            paddedPayload.put((byte) nextHeader);
-
-            // Generate Initialization Vector
-            byte[] iv = new byte[AES_CBC_IV_LEN];
-            new SecureRandom().nextBytes(iv);
-            IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
-            SecretKeySpec secretKeySpec = new SecretKeySpec(key, AES);
-
-            // Encrypt payload
-            Cipher cipher = Cipher.getInstance(AES_CBC);
-            cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
-            byte[] encrypted = cipher.doFinal(getByteArrayFromBuffer(paddedPayload));
-
-            // Build ciphertext
-            ByteBuffer cipherText = ByteBuffer.allocate(AES_CBC_IV_LEN + encrypted.length);
-            cipherText.put(iv);
-            cipherText.put(encrypted);
-
-            return getByteArrayFromBuffer(cipherText);
-        }
     }
 
     private static int addAndWrapForChecksum(int currentChecksum, int value) {
@@ -436,15 +457,14 @@
         return (short) ((~val) & 0xffff);
     }
 
-    public static int calculateEspPacketSize(
+    public static short calculateEspPacketSize( // returns the total ESP packet length
             int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) {
-        final int ESP_HDRLEN = 4 + 4; // SPI + Seq#
         final int ICV_LEN = authTruncLen / 8; // Auth trailer; based on truncation length
-        payloadLen += cryptIvLength; // Initialization Vector
 
         // Align to block size of encryption algorithm
         payloadLen = calculateEspEncryptedLength(payloadLen, cryptBlockSize);
-        return payloadLen + ESP_HDRLEN + ICV_LEN;
+        payloadLen += cryptIvLength; // IV is sent unencrypted, so it is added after alignment
+        return (short) (payloadLen + ESP_HDRLEN + ICV_LEN); // header + body + ICV
     }
 
     private static int calculateEspEncryptedLength(int payloadLen, int cryptBlockSize) {
@@ -475,6 +495,237 @@
         }
     }
 
+    public abstract static class EspCipher { // produces the ESP body via getCipherText()
+        protected static final int SALT_LEN_UNUSED = 0;
+        // Salted subclasses treat the last saltLen bytes of key as the salt.
+        public final String algoName;
+        public final int blockSize;
+        public final byte[] key;
+        public final int ivLen;
+        public final int saltLen;
+        protected byte[] mIv; // random by default; see updateIv()
+
+        public EspCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) {
+            this.algoName = algoName;
+            this.blockSize = blockSize;
+            this.key = key;
+            this.ivLen = ivLen;
+            this.saltLen = saltLen;
+            this.mIv = getIv(ivLen); // fresh random IV for each cipher instance
+        }
+
+        public void updateIv(byte[] iv) { // replaces the IV generated by the constructor
+            this.mIv = iv;
+        }
+
+        public static byte[] getPaddedPayload(int nextHeader, byte[] payload, int blockSize) {
+            final int paddedLen = calculateEspEncryptedLength(payload.length, blockSize);
+            final ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen);
+            paddedPayload.put(payload);
+
+            // Add padding - consecutive integers from 0x01
+            byte pad = 1;
+            while (paddedPayload.position() < paddedPayload.limit() - ESP_TRAILER_LEN) {
+                paddedPayload.put((byte) pad++);
+            }
+
+            // Trailer: one padding-length byte, then one next-header byte
+            paddedPayload.put((byte) (paddedLen - ESP_TRAILER_LEN - payload.length));
+            paddedPayload.put((byte) nextHeader);
+
+            return getByteArrayFromBuffer(paddedPayload);
+        }
+
+        private static byte[] getIv(int ivLen) { // ivLen bytes from a new SecureRandom
+            final byte[] iv = new byte[ivLen];
+            new SecureRandom().nextBytes(iv);
+            return iv;
+        }
+
+        public abstract byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
+                throws GeneralSecurityException; // bytes placed after the SPI and seq number
+    }
+
+    public static final class EspCipherNull extends EspCipher { // no encryption
+        private static final String CRYPT_NULL = "CRYPT_NULL";
+        private static final int IV_LEN_UNUSED = 0;
+        private static final byte[] KEY_UNUSED = new byte[0];
+
+        private static final EspCipherNull sInstance = new EspCipherNull(); // shared instance
+
+        private EspCipherNull() {
+            super(CRYPT_NULL, ESP_BLK_SIZE, KEY_UNUSED, IV_LEN_UNUSED, SALT_LEN_UNUSED);
+        }
+
+        public static EspCipherNull getInstance() { // returns the shared instance
+            return sInstance;
+        }
+
+        @Override
+        public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
+                throws GeneralSecurityException {
+            return getPaddedPayload(nextHeader, payload, blockSize); // plaintext, no IV
+        }
+    }
+
+    public static final class EspCryptCipher extends EspCipher { // AES-CBC or AES-CTR only
+        public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen) {
+            this(algoName, blockSize, key, ivLen, SALT_LEN_UNUSED); // key has no salt suffix
+        }
+
+        public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) {
+            super(algoName, blockSize, key, ivLen, saltLen);
+        }
+
+        @Override
+        public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
+                throws GeneralSecurityException {
+            final IvParameterSpec ivParameterSpec;
+            final SecretKeySpec secretKeySpec;
+            // CBC uses mIv directly; CTR builds salt || IV || initial counter.
+            if (AES_CBC.equals(algoName)) {
+                ivParameterSpec = new IvParameterSpec(mIv);
+                secretKeySpec = new SecretKeySpec(key, algoName);
+            } else if (AES_CTR.equals(algoName)) {
+                // Provided key consists of encryption/decryption key plus 4-byte salt. Salt is used
+                // with ESP payload IV and initial block counter value to build IvParameterSpec.
+                final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen);
+                final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length);
+                secretKeySpec = new SecretKeySpec(secretKey, algoName);
+
+                final ByteBuffer ivParameterBuffer =
+                        ByteBuffer.allocate(mIv.length + saltLen + AES_CTR_INITIAL_COUNTER.length);
+                ivParameterBuffer.put(salt);
+                ivParameterBuffer.put(mIv);
+                ivParameterBuffer.put(AES_CTR_INITIAL_COUNTER);
+                ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array());
+            } else {
+                throw new IllegalArgumentException("Invalid algorithm " + algoName);
+            }
+
+            // Encrypt payload
+            final Cipher cipher = Cipher.getInstance(algoName);
+            cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
+            final byte[] encrypted =
+                    cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize));
+
+            // Build ciphertext: the IV (sent in the clear) followed by the encrypted bytes
+            final ByteBuffer cipherText = ByteBuffer.allocate(mIv.length + encrypted.length);
+            cipherText.put(mIv);
+            cipherText.put(encrypted);
+
+            return getByteArrayFromBuffer(cipherText);
+        }
+    }
+
+    public static final class EspAeadCipher extends EspCipher { // encrypts and authenticates
+        public final int icvLen; // ICV length in bytes, produced by the cipher itself
+
+        public EspAeadCipher(
+                String algoName, int blockSize, byte[] key, int ivLen, int icvLen, int saltLen) {
+            super(algoName, blockSize, key, ivLen, saltLen);
+            this.icvLen = icvLen;
+        }
+
+        @Override
+        public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
+                throws GeneralSecurityException {
+            // Provided key consists of encryption/decryption key plus salt. Salt is used
+            // with ESP payload IV to build IvParameterSpec.
+            final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen);
+            final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length);
+
+            final SecretKeySpec secretKeySpec = new SecretKeySpec(secretKey, algoName);
+
+            final ByteBuffer ivParameterBuffer = ByteBuffer.allocate(saltLen + mIv.length);
+            ivParameterBuffer.put(salt);
+            ivParameterBuffer.put(mIv);
+            final IvParameterSpec ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array());
+            // AAD: the SPI and sequence number, authenticated but not encrypted.
+            final ByteBuffer aadBuffer = ByteBuffer.allocate(ESP_HDRLEN);
+            aadBuffer.putInt(spi);
+            aadBuffer.putInt(seqNum);
+
+            // Encrypt payload
+            final Cipher cipher = Cipher.getInstance(algoName);
+            cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
+            cipher.updateAAD(aadBuffer.array());
+            final byte[] encryptedTextAndIcv =
+                    cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize));
+
+            // Build ciphertext: IV, then the encrypted payload and ICV returned by doFinal()
+            final ByteBuffer cipherText =
+                    ByteBuffer.allocate(mIv.length + encryptedTextAndIcv.length);
+            cipherText.put(mIv);
+            cipherText.put(encryptedTextAndIcv);
+
+            return getByteArrayFromBuffer(cipherText);
+        }
+    }
+
+    public static class EspAuth { // computes the ICV appended after the ESP body
+        public final String algoName;
+        public final byte[] key;
+        public final int icvLen;
+        // Algorithms that getIcv() computes directly via Mac.getInstance().
+        private static final Set<String> JCE_SUPPORTED_MACS = new ArraySet<>();
+
+        static {
+            JCE_SUPPORTED_MACS.add(HMAC_MD5);
+            JCE_SUPPORTED_MACS.add(HMAC_SHA1);
+            JCE_SUPPORTED_MACS.add(HMAC_SHA_256);
+            JCE_SUPPORTED_MACS.add(HMAC_SHA_384);
+            JCE_SUPPORTED_MACS.add(HMAC_SHA_512);
+            JCE_SUPPORTED_MACS.add(AES_CMAC);
+        }
+
+        public EspAuth(String algoName, byte[] key, int icvLen) {
+            this.algoName = algoName;
+            this.key = key;
+            this.icvLen = icvLen;
+        }
+
+        public byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
+            if (AES_XCBC.equals(algoName)) {
+                // AES-XCBC is not in JCE_SUPPORTED_MACS, so AesXCbcImpl computes it instead.
+                return new AesXCbcImpl().mac(key, authenticatedSection, true /* needTruncation */);
+            } else if (JCE_SUPPORTED_MACS.contains(algoName)) {
+                final Mac mac = Mac.getInstance(algoName);
+                final SecretKeySpec authKey = new SecretKeySpec(key, algoName);
+                mac.init(authKey);
+
+                final ByteBuffer buffer = ByteBuffer.wrap(mac.doFinal(authenticatedSection));
+                final byte[] icv = new byte[icvLen]; // keep only the first icvLen MAC bytes
+                buffer.get(icv); // throws BufferUnderflowException if icvLen > MAC length
+                return icv;
+            } else {
+                throw new IllegalArgumentException("Invalid algorithm: " + algoName);
+            }
+        }
+    }
+
+    public static final class EspAuthNull extends EspAuth { // no authentication
+        private static final String AUTH_NULL = "AUTH_NULL";
+        private static final int ICV_LEN_UNUSED = 0;
+        private static final byte[] KEY_UNUSED = new byte[0];
+        private static final byte[] ICV_EMPTY = new byte[0];
+
+        private static final EspAuthNull sInstance = new EspAuthNull(); // shared instance
+
+        private EspAuthNull() {
+            super(AUTH_NULL, KEY_UNUSED, ICV_LEN_UNUSED);
+        }
+
+        public static EspAuthNull getInstance() { // returns the shared instance
+            return sInstance;
+        }
+
+        @Override
+        public byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
+            return ICV_EMPTY; // nothing is appended after the ESP body
+        }
+    }
+
     /*
      * Debug printing
      */
diff --git a/tests/cts/net/src/android/net/cts/TestNetworkRunnable.java b/tests/cts/net/src/android/net/cts/TestNetworkRunnable.java
new file mode 100644
index 0000000..0eb5644
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/TestNetworkRunnable.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts;
+
+import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
+import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
+
+import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
+
+import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.LinkAddress;
+import android.net.Network;
+import android.net.TestNetworkInterface;
+import android.net.TestNetworkManager;
+import android.net.cts.util.CtsNetUtils;
+
+import androidx.test.InstrumentationRegistry;
+
+import com.android.compatibility.common.util.ThrowingRunnable;
+
+import java.net.Inet4Address;
+import java.net.InetAddress;
+
+/** This class supports running a test with a test network. */
+public class TestNetworkRunnable implements ThrowingRunnable {
+    private static final int IP4_PREFIX_LEN = 32;
+    private static final int IP6_PREFIX_LEN = 128;
+    // NOTE(review): 192.2.0.2 is outside TEST-NET-1 (192.0.2.0/24); confirm it is intended.
+    private static final InetAddress DEFAULT_ADDRESS_4 =
+            InetAddress.parseNumericAddress("192.2.0.2");
+    private static final InetAddress DEFAULT_ADDRESS_6 =
+            InetAddress.parseNumericAddress("2001:db8:1::2");
+
+    private static final Context sContext = InstrumentationRegistry.getContext();
+    private static final ConnectivityManager sCm =
+            sContext.getSystemService(ConnectivityManager.class);
+
+    private final Test mTest;
+
+    public TestNetworkRunnable(Test test) {
+        mTest = test;
+    }
+
+    private void runTest() throws Exception { // sets up a TUN network, runs mTest, cleans up
+        final TestNetworkManager tnm = sContext.getSystemService(TestNetworkManager.class);
+
+        // Non-final; these variables ensure we clean up properly after our test if we
+        // have allocated test network resources
+        TestNetworkInterface testIface = null;
+        TestNetworkCallback tunNetworkCallback = null;
+
+        final CtsNetUtils ctsNetUtils = new CtsNetUtils(sContext);
+        final InetAddress[] addresses = mTest.getTestNetworkAddresses();
+        final LinkAddress[] linkAddresses = new LinkAddress[addresses.length];
+        for (int i = 0; i < addresses.length; i++) {
+            InetAddress address = addresses[i];
+            if (address instanceof Inet4Address) {
+                linkAddresses[i] = new LinkAddress(address, IP4_PREFIX_LEN);
+            } else {
+                linkAddresses[i] = new LinkAddress(address, IP6_PREFIX_LEN);
+            }
+        }
+
+        try {
+            // Build underlying test network
+            testIface = tnm.createTunInterface(linkAddresses);
+
+            // Hold on to this callback to ensure network does not get reaped.
+            tunNetworkCallback = ctsNetUtils.setupAndGetTestNetwork(testIface.getInterfaceName());
+
+            mTest.runTest(testIface, tunNetworkCallback);
+        } finally {
+            try {
+                mTest.cleanupTest();
+            } catch (Exception ignored) {
+                // Best effort: a cleanup failure must not prevent releasing network resources.
+            }
+
+            if (testIface != null) {
+                testIface.getFileDescriptor().close();
+            }
+
+            // The callback is still null if interface or network setup threw before assignment.
+            if (tunNetworkCallback != null) {
+                sCm.unregisterNetworkCallback(tunNetworkCallback);
+                final Network testNetwork = tunNetworkCallback.currentNetwork;
+                if (testNetwork != null) {
+                    tnm.teardownTestNetwork(testNetwork);
+                }
+            }
+        }
+    }
+
+    @Override
+    public void run() throws Exception {
+        if (sContext.checkSelfPermission(MANAGE_TEST_NETWORKS) == PERMISSION_GRANTED) {
+            runTest();
+        } else {
+            runWithShellPermissionIdentity(this::runTest, MANAGE_TEST_NETWORKS);
+        }
+    }
+
+    /** Interface for test caller to configure the test that will be run with a test network */
+    public interface Test {
+        /** Runs the test with a test network */
+        void runTest(TestNetworkInterface testIface, TestNetworkCallback tunNetworkCallback)
+                throws Exception;
+
+        /** Cleans up when the test is finished or interrupted */
+        void cleanupTest();
+
+        /** Returns the IP addresses that will be used by the test network */
+        default InetAddress[] getTestNetworkAddresses() {
+            return new InetAddress[] {DEFAULT_ADDRESS_4, DEFAULT_ADDRESS_6};
+        }
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/TunUtils.java b/tests/cts/net/src/android/net/cts/TunUtils.java
index 7887385..d8e39b4 100644
--- a/tests/cts/net/src/android/net/cts/TunUtils.java
+++ b/tests/cts/net/src/android/net/cts/TunUtils.java
@@ -147,6 +147,10 @@
         return espPkt; // We've found the packet we're looking for.
     }
 
+    public byte[] awaitEspPacket(int spi, boolean useEncap) throws Exception {
+        return awaitPacket((pkt) -> isEsp(pkt, spi, useEncap)); // match SPI/encap via isEsp()
+    }
+
     private static boolean isSpiEqual(byte[] pkt, int espOffset, int spi) {
         // Check SPI byte by byte.
         return pkt[espOffset] == (byte) ((spi >>> 24) & 0xff)