Merge "Fix TetheringPrivilegedTests crash"
diff --git a/service/src/com/android/server/connectivity/TcpKeepaliveController.java b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
index c480594..acfbb3c 100644
--- a/service/src/com/android/server/connectivity/TcpKeepaliveController.java
+++ b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
@@ -16,6 +16,7 @@
 package com.android.server.connectivity;
 
 import static android.net.SocketKeepalive.DATA_RECEIVED;
+import static android.net.SocketKeepalive.ERROR_INVALID_IP_ADDRESS;
 import static android.net.SocketKeepalive.ERROR_INVALID_SOCKET;
 import static android.net.SocketKeepalive.ERROR_SOCKET_NOT_IDLE;
 import static android.net.SocketKeepalive.ERROR_UNSUPPORTED;
@@ -29,6 +30,8 @@
 import static android.system.OsConstants.IP_TTL;
 import static android.system.OsConstants.TIOCOUTQ;
 
+import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
+
 import android.annotation.NonNull;
 import android.net.InvalidPacketException;
 import android.net.NetworkUtils;
@@ -36,7 +39,6 @@
 import android.net.TcpKeepalivePacketData;
 import android.net.TcpKeepalivePacketDataParcelable;
 import android.net.TcpRepairWindow;
-import android.net.util.KeepalivePacketDataUtil;
 import android.os.Handler;
 import android.os.MessageQueue;
 import android.os.Messenger;
@@ -46,12 +48,18 @@
 import android.util.SparseArray;
 
 import com.android.internal.annotations.GuardedBy;
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.IpUtils;
 import com.android.server.connectivity.KeepaliveTracker.KeepaliveInfo;
 
 import java.io.FileDescriptor;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.net.SocketException;
+import java.net.UnknownHostException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 
 /**
  * Manage tcp socket which offloads tcp keepalive.
@@ -82,6 +90,8 @@
 
     private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
 
+    private static final int TCP_HEADER_LENGTH = 20;
+
     // Reference include/uapi/linux/tcp.h
     private static final int TCP_REPAIR = 19;
     private static final int TCP_REPAIR_QUEUE = 20;
@@ -112,12 +122,81 @@
             throws InvalidPacketException, InvalidSocketException {
         try {
             final TcpKeepalivePacketDataParcelable tcpDetails = switchToRepairMode(fd);
-            return KeepalivePacketDataUtil.fromStableParcelable(tcpDetails);
+            // TODO: consider building a TcpKeepalivePacketData directly from switchToRepairMode
+            return fromStableParcelable(tcpDetails);
         } catch (InvalidPacketException | InvalidSocketException e) {
             switchOutOfRepairMode(fd);
             throw e;
         }
     }
+
+    /**
+     * Creates a {@link TcpKeepalivePacketData} from its stable parcelable form.
+     * Only IPv4 is handled; anything else is rejected with ERROR_INVALID_IP_ADDRESS.
+     */
+    @VisibleForTesting
+    public static TcpKeepalivePacketData fromStableParcelable(
+            TcpKeepalivePacketDataParcelable tcpDetails) throws InvalidPacketException {
+        final byte[] src = tcpDetails.srcAddress;
+        final byte[] dst = tcpDetails.dstAddress;
+        final boolean isV4 = (src != null) && (dst != null)
+                && (src.length == 4 /* V4 IP length */)
+                && (dst.length == 4 /* V4 IP length */);
+        if (!isV4) {
+            // TODO: support ipv6
+            throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
+        }
+        final byte[] packet = buildV4Packet(tcpDetails);
+        try {
+            return new TcpKeepalivePacketData(
+                    InetAddress.getByAddress(src), tcpDetails.srcPort,
+                    InetAddress.getByAddress(dst), tcpDetails.dstPort,
+                    packet,
+                    tcpDetails.seq, tcpDetails.ack, tcpDetails.rcvWnd, tcpDetails.rcvWndScale,
+                    tcpDetails.tos, tcpDetails.ttl);
+        } catch (UnknownHostException e) {
+            throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
+        }
+    }
+
+    /**
+     * Build ipv4 tcp keepalive packet, not including the link-layer header.
+     * @param tcpDetails connection details; the caller ensures both addresses are 4 bytes long.
+     * @return the IPv4 header followed by the TCP header, with both checksum fields set.
+     */
+    // TODO : if this code is ever moved to the network stack, factorize constants with the ones
+    // over there.
+    // TODO: consider using Ipv4Utils.buildTcpv4Packet() instead
+    private static byte[] buildV4Packet(TcpKeepalivePacketDataParcelable tcpDetails) {
+        final int length = IPV4_HEADER_MIN_LEN + TCP_HEADER_LENGTH;
+        final ByteBuffer buf = ByteBuffer.allocate(length).order(ByteOrder.BIG_ENDIAN);
+        buf.put((byte) 0x45);                       // IP version and IHL
+        buf.put((byte) tcpDetails.tos);             // TOS
+        buf.putShort((short) length);               // Total length, no payload
+        buf.putInt(0x00004000);                     // ID, flags=DF, offset
+        buf.put((byte) tcpDetails.ttl);             // TTL
+        buf.put((byte) IPPROTO_TCP);                // Protocol
+        final int ipChecksumOffset = buf.position();
+        buf.putShort((short) 0);                    // IP checksum, filled in below
+        buf.put(tcpDetails.srcAddress);
+        buf.put(tcpDetails.dstAddress);
+        buf.putShort((short) tcpDetails.srcPort);
+        buf.putShort((short) tcpDetails.dstPort);
+        buf.putInt(tcpDetails.seq);                 // Sequence Number
+        buf.putInt(tcpDetails.ack);                 // ACK
+        buf.putShort((short) 0x5010);               // TCP length=5, flags=ACK
+        buf.putShort((short) (tcpDetails.rcvWnd >> tcpDetails.rcvWndScale));   // Window size
+        final int tcpChecksumOffset = buf.position();
+        buf.putShort((short) 0);                    // TCP checksum, filled in below
+        buf.putShort((short) 0);                    // Urgent pointer, zero as URG is not set
+
+        // Both checksum fields were written as zero above; overwrite them in place.
+        buf.putShort(ipChecksumOffset, IpUtils.ipChecksum(buf, 0));
+        buf.putShort(tcpChecksumOffset, IpUtils.tcpChecksum(
+                buf, 0, IPV4_HEADER_MIN_LEN, TCP_HEADER_LENGTH));
+
+        return buf.array();
+    }
+
     /**
      * Switch the tcp socket to repair mode and query detail tcp information.
      *
diff --git a/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java b/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java
index 8498b6f..6afa4e9 100644
--- a/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java
+++ b/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java
@@ -27,6 +27,7 @@
 import android.os.Build;
 import android.util.Log;
 
+import com.android.server.connectivity.TcpKeepaliveController;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
 
@@ -81,7 +82,7 @@
         testInfo.tos = tos;
         testInfo.ttl = ttl;
         try {
-            resultData = KeepalivePacketDataUtil.fromStableParcelable(testInfo);
+            resultData = TcpKeepaliveController.fromStableParcelable(testInfo);
         } catch (InvalidPacketException e) {
             fail("InvalidPacketException: " + e);
         }
@@ -155,7 +156,7 @@
         testInfo.ttl = ttl;
         TcpKeepalivePacketData testData = null;
         TcpKeepalivePacketDataParcelable resultData = null;
-        testData = KeepalivePacketDataUtil.fromStableParcelable(testInfo);
+        testData = TcpKeepaliveController.fromStableParcelable(testInfo);
         resultData = KeepalivePacketDataUtil.toStableParcelable(testData);
         assertArrayEquals(resultData.srcAddress, IPV4_KEEPALIVE_SRC_ADDR);
         assertArrayEquals(resultData.dstAddress, IPV4_KEEPALIVE_DST_ADDR);
@@ -198,11 +199,11 @@
         testParcel.ttl = ttl;
 
         final KeepalivePacketData testData =
-                KeepalivePacketDataUtil.fromStableParcelable(testParcel);
+                TcpKeepaliveController.fromStableParcelable(testParcel);
         final TcpKeepalivePacketDataParcelable parsedParcelable =
                 KeepalivePacketDataUtil.parseTcpKeepalivePacketData(testData);
         final TcpKeepalivePacketData roundTripData =
-                KeepalivePacketDataUtil.fromStableParcelable(parsedParcelable);
+                TcpKeepaliveController.fromStableParcelable(parsedParcelable);
 
         // Generated packet is the same, but rcvWnd / wndScale will differ if scale is non-zero
         assertTrue(testData.getPacket().length > 0);
@@ -210,11 +211,11 @@
 
         testParcel.rcvWndScale = 0;
         final KeepalivePacketData noScaleTestData =
-                KeepalivePacketDataUtil.fromStableParcelable(testParcel);
+                TcpKeepaliveController.fromStableParcelable(testParcel);
         final TcpKeepalivePacketDataParcelable noScaleParsedParcelable =
                 KeepalivePacketDataUtil.parseTcpKeepalivePacketData(noScaleTestData);
         final TcpKeepalivePacketData noScaleRoundTripData =
-                KeepalivePacketDataUtil.fromStableParcelable(noScaleParsedParcelable);
+                TcpKeepaliveController.fromStableParcelable(noScaleParsedParcelable);
         assertEquals(noScaleTestData, noScaleRoundTripData);
         assertTrue(noScaleTestData.getPacket().length > 0);
         assertArrayEquals(noScaleTestData.getPacket(), noScaleRoundTripData.getPacket());