Merge "Add FrameworksNetCommonTests to CTS"
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 1b7d290..4180ea4 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -23,12 +23,17 @@
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
 import static android.provider.Settings.Global.NETWORK_METERED_MULTIPATH_PREFERENCE;
+import static android.system.OsConstants.AF_INET;
+import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.AF_UNSPEC;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
 
 import android.app.Instrumentation;
 import android.app.PendingIntent;
+import android.app.UiAutomation;
 import android.content.BroadcastReceiver;
 import android.content.ComponentName;
 import android.content.ContentResolver;
@@ -46,8 +51,10 @@
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkInfo.State;
 import android.net.NetworkRequest;
+import android.net.SocketKeepalive;
 import android.net.wifi.WifiManager;
 import android.os.Looper;
+import android.os.MessageQueue;
 import android.os.SystemClock;
 import android.os.SystemProperties;
 import android.provider.Settings;
@@ -64,11 +71,13 @@
 
 import libcore.io.Streams;
 
+import java.io.FileDescriptor;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.io.OutputStream;
 import java.net.HttpURLConnection;
+import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
@@ -79,6 +88,7 @@
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executor;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.regex.Matcher;
@@ -96,7 +106,11 @@
     private static final int HOST_ADDRESS = 0x7f000001;// represent ip 127.0.0.1
     private static final String TEST_HOST = "connectivitycheck.gstatic.com";
     private static final int SOCKET_TIMEOUT_MS = 2000;
+    private static final int CONNECT_TIMEOUT_MS = 2000;  // Timeout passed to Socket#connect.
+    private static final int KEEPALIVE_CALLBACK_TIMEOUT_MS = 2000;  // Max wait for one callback.
+    private static final int KEEPALIVE_SOCKET_TIMEOUT_MS = 5000;  // setSoTimeout value.
     private static final int SEND_BROADCAST_TIMEOUT = 30000;
+    private static final int MIN_KEEPALIVE_INTERVAL = 10;  // Passed to SocketKeepalive#start.
     private static final int NETWORK_CHANGE_METEREDNESS_TIMEOUT = 5000;
     private static final int NUM_TRIES_MULTIPATH_PREF_CHECK = 20;
     private static final long INTERVAL_MULTIPATH_PREF_CHECK_MS = 500;
@@ -126,7 +140,8 @@
             new HashMap<Integer, NetworkConfig>();
     boolean mWifiConnectAttempted;
     private TestNetworkCallback mCellNetworkCallback;
-
+    private UiAutomation mUiAutomation;  // Used to adopt/drop the shell permission identity.
+    private boolean mShellPermissionIdentityAdopted;  // True while the identity is adopted.
 
     @Override
     protected void setUp() throws Exception {
@@ -153,6 +168,8 @@
                 mNetworks.put(n.type, n);
             } catch (Exception e) {}
         }
+        mUiAutomation = mInstrumentation.getUiAutomation();
+        mShellPermissionIdentityAdopted = false;
     }
 
     @Override
@@ -164,6 +181,7 @@
         if (cellConnectAttempted()) {
             disconnectFromCell();
         }
+        dropShellPermissionIdentity();
         super.tearDown();
     }
 
@@ -1037,4 +1055,226 @@
             setWifiMeteredStatus(ssid, oldMeteredSetting);
         }
     }
+
+    // TODO: move the following socket keepalive tests to a dedicated test class.
+    /**
+     * {@link SocketKeepalive.Callback} that records every callback it receives in a queue so a
+     * test can block until the next one fires and assert on its type and error code.
+     */
+    private static class TestSocketKeepaliveCallback extends SocketKeepalive.Callback {
+        public enum CallbackType { ON_STARTED, ON_STOPPED, ON_ERROR }
+
+        /** A recorded callback: its type plus an error code (0 for ON_STARTED/ON_STOPPED). */
+        public static class CallbackValue {
+            public final CallbackType callbackType;
+            public final int error;
+
+            private CallbackValue(final CallbackType type, final int error) {
+                this.callbackType = type;
+                this.error = error;
+            }
+
+            public static class OnStartedCallback extends CallbackValue {
+                OnStartedCallback() { super(CallbackType.ON_STARTED, 0); }
+            }
+
+            public static class OnStoppedCallback extends CallbackValue {
+                OnStoppedCallback() { super(CallbackType.ON_STOPPED, 0); }
+            }
+
+            public static class OnErrorCallback extends CallbackValue {
+                OnErrorCallback(final int error) { super(CallbackType.ON_ERROR, error); }
+            }
+
+            @Override
+            public boolean equals(Object o) {
+                // Null or an instance of another class is never equal.
+                if (o == null || o.getClass() != getClass()) return false;
+                final CallbackValue other = (CallbackValue) o;
+                return callbackType == other.callbackType && error == other.error;
+            }
+
+            @Override
+            public int hashCode() {
+                // Built from the same fields as equals() so equal values hash alike.
+                return 31 * callbackType.hashCode() + error;
+            }
+
+            @Override
+            public String toString() {
+                return String.format("%s(%s, %d)", getClass().getSimpleName(), callbackType, error);
+            }
+        }
+
+        private final LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
+
+        @Override
+        public void onStarted() {
+            mCallbacks.add(new CallbackValue.OnStartedCallback());
+        }
+
+        @Override
+        public void onStopped() {
+            mCallbacks.add(new CallbackValue.OnStoppedCallback());
+        }
+
+        @Override
+        public void onError(final int error) {
+            mCallbacks.add(new CallbackValue.OnErrorCallback(error));
+        }
+
+        /**
+         * Waits up to KEEPALIVE_CALLBACK_TIMEOUT_MS for the next recorded callback.
+         *
+         * @return the next callback, never null: a timeout or interrupt fails the test instead
+         */
+        public CallbackValue pollCallback() {
+            CallbackValue value = null;
+            try {
+                value = mCallbacks.poll(KEEPALIVE_CALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+            } catch (InterruptedException e) {
+                // Restore the interrupt flag and report what really happened, not a timeout.
+                Thread.currentThread().interrupt();
+                fail("Interrupted while waiting for keepalive callback: " + e);
+            }
+            if (value == null) {
+                fail("Callback not seen after " + KEEPALIVE_CALLBACK_TIMEOUT_MS + " ms");
+            }
+            return value;
+        }
+
+        /** Asserts that the next recorded callback equals {@code expectedCallback}. */
+        private void expectCallback(CallbackValue expectedCallback) {
+            final CallbackValue actualCallback = pollCallback();
+            assertEquals(expectedCallback, actualCallback);
+        }
+
+        /** Asserts that the next callback is onStarted. */
+        public void expectStarted() {
+            expectCallback(new CallbackValue.OnStartedCallback());
+        }
+
+        /** Asserts that the next callback is onStopped. */
+        public void expectStopped() {
+            expectCallback(new CallbackValue.OnStoppedCallback());
+        }
+
+        /** Asserts that the next callback is onError carrying {@code error}. */
+        public void expectError(int error) {
+            expectCallback(new CallbackValue.OnErrorCallback(error));
+        }
+    }
+
+    /**
+     * Resolves {@code hostname} and returns the first address of the requested family.
+     *
+     * @param hostname name to resolve
+     * @param family AF_INET for IPv4, AF_INET6 for IPv6, or AF_UNSPEC to accept any address
+     * @return the first matching address, or null if none of the resolved addresses matches
+     * @throws Exception if the name lookup fails
+     */
+    private InetAddress getAddrByName(final String hostname, final int family) throws Exception {
+        for (final InetAddress candidate : InetAddress.getAllByName(hostname)) {
+            final boolean matches = family == AF_UNSPEC
+                    || (family == AF_INET && candidate instanceof Inet4Address)
+                    || (family == AF_INET6 && candidate instanceof Inet6Address);
+            if (matches) return candidate;
+        }
+        return null;
+    }
+
+    /**
+     * Creates a socket from {@code network}'s socket factory and connects it to {@code host}.
+     *
+     * @param network network whose socket factory creates the socket
+     * @param host hostname to resolve and connect to
+     * @param port destination port
+     * @param socketTimeOut value passed to {@link Socket#setSoTimeout}, in milliseconds
+     * @param family AF_INET, AF_INET6 or AF_UNSPEC; selects which resolved address is used
+     * @return the connected socket; the caller is responsible for closing it
+     * @throws Exception if resolving or connecting fails; the socket is closed first
+     */
+    private Socket getConnectedSocket(final Network network, final String host, final int port,
+            final int socketTimeOut, final int family) throws Exception {
+        final Socket s = network.getSocketFactory().createSocket();
+        boolean connected = false;
+        try {
+            final InetAddress addr = getAddrByName(host, family);
+            if (addr == null) fail("Fail to get destination address for " + family);
+
+            final InetSocketAddress sockAddr = new InetSocketAddress(addr, port);
+            s.setSoTimeout(socketTimeOut);
+            s.connect(sockAddr, CONNECT_TIMEOUT_MS);
+            connected = true;
+            return s;
+        } finally {
+            // fail() throws an Error, which a catch (Exception) would miss; closing in finally
+            // releases the socket on every failure path.
+            if (!connected) s.close();
+        }
+    }
+
+    /**
+     * Probes keepalive support by starting a keepalive on a freshly connected IPv4 socket on the
+     * network returned by ensureWifiConnected().
+     *
+     * @return true if the keepalive started (it is stopped again before returning), false if the
+     *         callback reported {@link SocketKeepalive#ERROR_UNSUPPORTED}
+     * @throws Exception if connecting the probe socket fails
+     */
+    private boolean isKeepaliveSupported() throws Exception {
+        final Network network = ensureWifiConnected();
+        final Executor executor = mContext.getMainExecutor();
+        final TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback();
+        try (Socket s = getConnectedSocket(network, TEST_HOST,
+                HTTP_PORT, KEEPALIVE_SOCKET_TIMEOUT_MS, AF_INET);
+                SocketKeepalive sk = mCm.createSocketKeepalive(network, s, executor, callback)) {
+            sk.start(MIN_KEEPALIVE_INTERVAL);
+            final TestSocketKeepaliveCallback.CallbackValue result = callback.pollCallback();
+            if (result == null) {
+                // Guard against a missing callback so the failure is explicit, not an NPE.
+                fail("No keepalive callback received after "
+                        + KEEPALIVE_CALLBACK_TIMEOUT_MS + " ms");
+                return false;
+            }
+            switch (result.callbackType) {
+                case ON_STARTED:
+                    sk.stop();
+                    callback.expectStopped();
+                    return true;
+                case ON_ERROR:
+                    if (result.error == SocketKeepalive.ERROR_UNSUPPORTED) return false;
+                    // Any other error is unexpected: fall through to the failure below.
+                default:
+                    fail("Got unexpected callback: " + result);
+                    return false;
+            }
+        }
+    }
+
+    /** Adopts the shell permission identity and records that it has to be dropped again. */
+    private void adoptShellPermissionIdentity() {
+        mUiAutomation.adoptShellPermissionIdentity();
+        // Set only once the adoption call has returned, so the flag mirrors the real state.
+        mShellPermissionIdentityAdopted = true;
+    }
+
+    /** Drops the shell permission identity if it was adopted; otherwise does nothing. */
+    private void dropShellPermissionIdentity() {
+        if (!mShellPermissionIdentityAdopted) {
+            return;
+        }
+        mUiAutomation.dropShellPermissionIdentity();
+        mShellPermissionIdentityAdopted = false;
+    }
+
+    /**
+     * Exercises TCP keepalive on an IPv4 socket: starts it while the socket is idle, checks that
+     * the socket can be neither written nor read while it is active, stops it, and finally checks
+     * that starting again with unread response data pending reports ERROR_SOCKET_NOT_IDLE.
+     */
+    public void testCreateTcpKeepalive() throws Exception {
+        adoptShellPermissionIdentity();
+
+        // Nothing to verify when keepalive is reported as unsupported.
+        if (!isKeepaliveSupported()) {
+            return;
+        }
+
+        final Network wifi = ensureWifiConnected();
+        final byte[] request = HTTP_REQUEST.getBytes("UTF-8");
+        // Only IPv4 TCP keepalive offload is supported so far.
+        // TODO: add an IPv6 test case once IPv6 TCP keepalive offload is supported.
+        try (Socket sock = getConnectedSocket(wifi, TEST_HOST, HTTP_PORT,
+                KEEPALIVE_SOCKET_TIMEOUT_MS, AF_INET)) {
+
+            // Starting keepalive offload must succeed while the socket is idle.
+            final Executor mainExecutor = mContext.getMainExecutor();
+            final TestSocketKeepaliveCallback recorder = new TestSocketKeepaliveCallback();
+            try (SocketKeepalive keepalive =
+                    mCm.createSocketKeepalive(wifi, sock, mainExecutor, recorder)) {
+                keepalive.start(MIN_KEEPALIVE_INTERVAL);
+                recorder.expectStarted();
+
+                // The app must not be able to write during keepalive offload.
+                final OutputStream toServer = sock.getOutputStream();
+                try {
+                    toServer.write(request);
+                    fail("Should not able to write");
+                } catch (IOException expected) { }
+                // The app must not be able to read during keepalive offload either.
+                final InputStream fromServer = sock.getInputStream();
+                final byte[] buffer = new byte[4096];
+                try {
+                    fromServer.read(buffer);
+                    fail("Should not able to read");
+                } catch (IOException expected) { }
+
+                keepalive.stop();
+                recorder.expectStopped();
+            }
+
+            // The socket must still be connected and open after the keepalive was stopped.
+            assertTrue(sock.isConnected());
+            assertFalse(sock.isClosed());
+
+            // Send a request so that the socket is no longer idle.
+            try {
+                sock.getOutputStream().write(request);
+            } catch (IOException e) {
+                fail("Failed to write data " + e);
+            }
+
+            // Wait until response data is signalled on the socket's file descriptor.
+            final CountDownLatch responseArrived = new CountDownLatch(1);
+            final FileDescriptor sockFd = sock.getFileDescriptor$();
+            final MessageQueue mainQueue = Looper.getMainLooper().getQueue();
+            mainQueue.addOnFileDescriptorEventListener(sockFd, EVENT_INPUT, (fd, mask) -> {
+                responseArrived.countDown();
+                return 0; // Returning 0 unregisters the listener.
+            });
+            final boolean gotResponse = responseArrived.await(2, TimeUnit.SECONDS);
+            if (!gotResponse) {
+                mainQueue.removeOnFileDescriptorEventListener(sockFd);
+                fail("Timeout: no response data");
+            }
+
+            // The response is still unread in the receive queue, so start() is expected to
+            // report ERROR_SOCKET_NOT_IDLE.
+            try (SocketKeepalive keepalive =
+                    mCm.createSocketKeepalive(wifi, sock, mainExecutor, recorder)) {
+                keepalive.start(MIN_KEEPALIVE_INTERVAL);
+                recorder.expectError(SocketKeepalive.ERROR_SOCKET_NOT_IDLE);
+            }
+
+        }
+    }
 }