Refresh conntrack entry timeout

Needed because the BPF maps offloads IPv4 traffic. The kernel can't
trace the offloaded traffic to keep the conntrack entry.

Bug: 190783768
Test: atest TetheringCoverageTests
Original-Change: https://android-review.googlesource.com/1566871
Merged-In: Idbcf686c9b2124b192944156ac5111be741744fb
Change-Id: Idbcf686c9b2124b192944156ac5111be741744fb
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index 24b9234..9b95dac 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -41,7 +41,9 @@
 import android.net.ip.ConntrackMonitor;
 import android.net.ip.ConntrackMonitor.ConntrackEventConsumer;
 import android.net.ip.IpServer;
+import android.net.netlink.ConntrackMessage;
 import android.net.netlink.NetlinkConstants;
+import android.net.netlink.NetlinkSocket;
 import android.net.netstats.provider.NetworkStatsProvider;
 import android.net.util.InterfaceParams;
 import android.net.util.SharedLog;
@@ -49,6 +51,7 @@
 import android.os.Handler;
 import android.os.SystemClock;
 import android.system.ErrnoException;
+import android.system.OsConstants;
 import android.text.TextUtils;
 import android.util.ArraySet;
 import android.util.Log;
@@ -121,6 +124,13 @@
     }
 
     @VisibleForTesting
+    static final int POLLING_CONNTRACK_TIMEOUT_MS = 60_000;
+    @VisibleForTesting
+    static final int NF_CONNTRACK_TCP_TIMEOUT_ESTABLISHED = 432000;
+    @VisibleForTesting
+    static final int NF_CONNTRACK_UDP_TIMEOUT_STREAM = 180;
+
+    @VisibleForTesting
     enum StatsType {
         STATS_PER_IFACE,
         STATS_PER_UID,
@@ -234,11 +244,17 @@
     private int mLastIPv4UpstreamIfindex = 0;
 
     // Runnable that used by scheduling next polling of stats.
-    private final Runnable mScheduledPollingTask = () -> {
+    private final Runnable mScheduledPollingStats = () -> {
         updateForwardedStats();
         maybeSchedulePollingStats();
     };
 
+    // Runnable that used by scheduling next polling of conntrack timeout.
+    private final Runnable mScheduledPollingConntrackTimeout = () -> {
+        maybeRefreshConntrackTimeout();
+        maybeSchedulePollingConntrackTimeout();
+    };
+
     // TODO: add BpfMap<TetherDownstream64Key, TetherDownstream64Value> retrieving function.
     @VisibleForTesting
     public abstract static class Dependencies {
@@ -268,13 +284,19 @@
         }
 
         /**
+         * Represents an estimate of elapsed time since boot in nanoseconds.
+         */
+        public long elapsedRealtimeNanos() {
+            return SystemClock.elapsedRealtimeNanos();
+        }
+
+        /**
          * Check OS Build at least S.
          *
          * TODO: move to BpfCoordinatorShim once the test doesn't need the mocked OS build for
          * testing different code flows concurrently.
          */
         public boolean isAtLeastS() {
-            // TODO: consider using ShimUtils.isAtLeastS.
             return SdkLevel.isAtLeastS();
         }
 
@@ -412,6 +434,7 @@
 
         mPollingStarted = true;
         maybeSchedulePollingStats();
+        maybeSchedulePollingConntrackTimeout();
 
         mLog.i("Polling started");
     }
@@ -427,9 +450,13 @@
     public void stopPolling() {
         if (!mPollingStarted) return;
 
-        // Stop scheduled polling tasks and poll the latest stats from BPF maps.
-        if (mHandler.hasCallbacks(mScheduledPollingTask)) {
-            mHandler.removeCallbacks(mScheduledPollingTask);
+        // Stop scheduled polling conntrack timeout.
+        if (mHandler.hasCallbacks(mScheduledPollingConntrackTimeout)) {
+            mHandler.removeCallbacks(mScheduledPollingConntrackTimeout);
+        }
+        // Stop scheduled polling stats and poll the latest stats from BPF maps.
+        if (mHandler.hasCallbacks(mScheduledPollingStats)) {
+            mHandler.removeCallbacks(mScheduledPollingStats);
         }
         updateForwardedStats();
         mPollingStarted = false;
@@ -1412,12 +1439,86 @@
         return addr6;
     }
 
-    // Support raw ip only.
-    // TODO: add ether ip support.
+    @Nullable
+    private Inet4Address ipv4MappedAddressBytesToIpv4Address(final byte[] addr46) {
+        if (addr46.length != 16) return null;
+        if (addr46[0] != 0 || addr46[1] != 0 || addr46[2] != 0 || addr46[3] != 0
+                || addr46[4] != 0 || addr46[5] != 0 || addr46[6] != 0 || addr46[7] != 0
+                || addr46[8] != 0 && addr46[9] != 0 || (addr46[10] & 0xff) != 0xff
+                || (addr46[11] & 0xff) != 0xff) {
+            return null;
+        }
+
+        final byte[] addr4 = new byte[4];
+        addr4[0] = addr46[12];
+        addr4[1] = addr46[13];
+        addr4[2] = addr46[14];
+        addr4[3] = addr46[15];
+
+        return parseIPv4Address(addr4);
+    }
+
     // TODO: parse CTA_PROTOINFO of conntrack event in ConntrackMonitor. For TCP, only add rules
     // while TCP status is established.
     @VisibleForTesting
     class BpfConntrackEventConsumer implements ConntrackEventConsumer {
+        // The upstream4 and downstream4 rules are built as the following tables. Only raw ip
+        // upstream interface is supported. Note that the field "lastUsed" is only updated by
+        // BPF program which records the last used time for a given rule.
+        // TODO: support ether ip upstream interface.
+        //
+        // NAT network topology:
+        //
+        //         public network (rawip)                 private network
+        //                   |                 UE                |
+        // +------------+    V    +------------+------------+    V    +------------+
+        // |   Sever    +---------+  Upstream  | Downstream +---------+   Client   |
+        // +------------+         +------------+------------+         +------------+
+        //
+        // upstream4 key and value:
+        //
+        // +------+------------------------------------------------+
+        // |      |      TetherUpstream4Key                        |
+        // +------+------+------+------+------+------+------+------+
+        // |field |iif   |dstMac|l4prot|src4  |dst4  |srcPor|dstPor|
+        // |      |      |      |o     |      |      |t     |t     |
+        // +------+------+------+------+------+------+------+------+
+        // |value |downst|downst|tcp/  |client|server|client|server|
+        // |      |ream  |ream  |udp   |      |      |      |      |
+        // +------+------+------+------+------+------+------+------+
+        //
+        // +------+---------------------------------------------------------------------+
+        // |      |      TetherUpstream4Value                                           |
+        // +------+------+------+------+------+------+------+------+------+------+------+
+        // |field |oif   |ethDst|ethSrc|ethPro|pmtu  |src46 |dst46 |srcPor|dstPor|lastUs|
+        // |      |      |mac   |mac   |to    |      |      |      |t     |t     |ed    |
+        // +------+------+------+------+------+------+------+------+------+------+------+
+        // |value |upstre|--    |--    |ETH_P_|1500  |upstre|server|upstre|server|--    |
+        // |      |am    |      |      |IP    |      |am    |      |am    |      |      |
+        // +------+------+------+------+------+------+------+------+------+------+------+
+        //
+        // downstream4 key and value:
+        //
+        // +------+------------------------------------------------+
+        // |      |      TetherDownstream4Key                      |
+        // +------+------+------+------+------+------+------+------+
+        // |field |iif   |dstMac|l4prot|src4  |dst4  |srcPor|dstPor|
+        // |      |      |      |o     |      |      |t     |t     |
+        // +------+------+------+------+------+------+------+------+
+        // |value |upstre|--    |tcp/  |server|upstre|server|upstre|
+        // |      |am    |      |udp   |      |am    |      |am    |
+        // +------+------+------+------+------+------+------+------+
+        //
+        // +------+---------------------------------------------------------------------+
+        // |      |      TetherDownstream4Value                                         |
+        // +------+------+------+------+------+------+------+------+------+------+------+
+        // |field |oif   |ethDst|ethSrc|ethPro|pmtu  |src46 |dst46 |srcPor|dstPor|lastUs|
+        // |      |      |mac   |mac   |to    |      |      |      |t     |t     |ed    |
+        // +------+------+------+------+------+------+------+------+------+------+------+
+        // |value |downst|client|downst|ETH_P_|1500  |server|client|server|client|--    |
+        // |      |ream  |      |ream  |IP    |      |      |      |      |      |      |
+        // +------+------+------+------+------+------+------+------+------+------+------+
+        //
         @NonNull
         private Tether4Key makeTetherUpstream4Key(
                 @NonNull ConntrackEvent e, @NonNull ClientInfo c) {
@@ -1751,14 +1852,89 @@
         return Math.max(DEFAULT_TETHER_OFFLOAD_POLL_INTERVAL_MS, configInterval);
     }
 
+    @Nullable
+    private Inet4Address parseIPv4Address(byte[] addrBytes) {
+        try {
+            final InetAddress ia = Inet4Address.getByAddress(addrBytes);
+            if (ia instanceof Inet4Address) return (Inet4Address) ia;
+        } catch (UnknownHostException | IllegalArgumentException e) {
+            mLog.e("Failed to parse IPv4 address: " + e);
+        }
+        return null;
+    }
+
+    // Update CTA_TUPLE_ORIG timeout for a given conntrack entry. Note that there will also be
+    // coming a conntrack event to notify updated timeout.
+    private void updateConntrackTimeout(byte proto, Inet4Address src4, short srcPort,
+            Inet4Address dst4, short dstPort) {
+        if (src4 == null || dst4 == null) return;
+
+        // TODO: consider acquiring the timeout setting from nf_conntrack_* variables.
+        // - proc/sys/net/netfilter/nf_conntrack_tcp_timeout_established
+        // - proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream
+        // See kernel document nf_conntrack-sysctl.txt.
+        final int timeoutSec = (proto == OsConstants.IPPROTO_TCP)
+                ? NF_CONNTRACK_TCP_TIMEOUT_ESTABLISHED
+                : NF_CONNTRACK_UDP_TIMEOUT_STREAM;
+        final byte[] msg = ConntrackMessage.newIPv4TimeoutUpdateRequest(
+                proto, src4, (int) srcPort, dst4, (int) dstPort, timeoutSec);
+        try {
+            NetlinkSocket.sendOneShotKernelMessage(OsConstants.NETLINK_NETFILTER, msg);
+        } catch (ErrnoException e) {
+            mLog.e("Error updating conntrack entry ("
+                    + "proto: " + proto + ", "
+                    + "src4: " + src4 + ", "
+                    + "srcPort: " + Short.toUnsignedInt(srcPort) + ", "
+                    + "dst4: " + dst4 + ", "
+                    + "dstPort: " + Short.toUnsignedInt(dstPort) + "), "
+                    + "msg: " + NetlinkConstants.hexify(msg) + ", "
+                    + "e: " + e);
+        }
+    }
+
+    private void maybeRefreshConntrackTimeout() {
+        final long now = mDeps.elapsedRealtimeNanos();
+
+        // Reverse the source and destination {address, port} from downstream value because
+        // #updateConntrackTimeout refresh the timeout of netlink attribute CTA_TUPLE_ORIG
+        // which is opposite direction for downstream map value.
+        mBpfCoordinatorShim.tetherOffloadRuleForEach(DOWNSTREAM, (k, v) -> {
+            if ((now - v.lastUsed) / 1_000_000 < POLLING_CONNTRACK_TIMEOUT_MS) {
+                updateConntrackTimeout((byte) k.l4proto,
+                        ipv4MappedAddressBytesToIpv4Address(v.dst46), (short) v.dstPort,
+                        ipv4MappedAddressBytesToIpv4Address(v.src46), (short) v.srcPort);
+            }
+        });
+
+        // TODO: Consider ignoring TCP traffic on upstream and monitor on downstream only
+        // because TCP is a bidirectional traffic. Probably don't need to extend timeout by
+        // both directions for TCP.
+        mBpfCoordinatorShim.tetherOffloadRuleForEach(UPSTREAM, (k, v) -> {
+            if ((now - v.lastUsed) / 1_000_000 < POLLING_CONNTRACK_TIMEOUT_MS) {
+                updateConntrackTimeout((byte) k.l4proto, parseIPv4Address(k.src4),
+                        (short) k.srcPort, parseIPv4Address(k.dst4), (short) k.dstPort);
+            }
+        });
+    }
+
     private void maybeSchedulePollingStats() {
         if (!mPollingStarted) return;
 
-        if (mHandler.hasCallbacks(mScheduledPollingTask)) {
-            mHandler.removeCallbacks(mScheduledPollingTask);
+        if (mHandler.hasCallbacks(mScheduledPollingStats)) {
+            mHandler.removeCallbacks(mScheduledPollingStats);
         }
 
-        mHandler.postDelayed(mScheduledPollingTask, getPollingInterval());
+        mHandler.postDelayed(mScheduledPollingStats, getPollingInterval());
+    }
+
+    private void maybeSchedulePollingConntrackTimeout() {
+        if (!mPollingStarted) return;
+
+        if (mHandler.hasCallbacks(mScheduledPollingConntrackTimeout)) {
+            mHandler.removeCallbacks(mScheduledPollingConntrackTimeout);
+        }
+
+        mHandler.postDelayed(mScheduledPollingConntrackTimeout, POLLING_CONNTRACK_TIMEOUT_MS);
     }
 
     // Return forwarding rule map. This is used for testing only.
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
index 073ca89..914e0d4 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
@@ -36,9 +36,13 @@
 import static android.system.OsConstants.ETH_P_IPV6;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
+import static android.system.OsConstants.NETLINK_NETFILTER;
 
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.staticMockMarker;
+import static com.android.networkstack.tethering.BpfCoordinator.NF_CONNTRACK_TCP_TIMEOUT_ESTABLISHED;
+import static com.android.networkstack.tethering.BpfCoordinator.NF_CONNTRACK_UDP_TIMEOUT_STREAM;
+import static com.android.networkstack.tethering.BpfCoordinator.POLLING_CONNTRACK_TIMEOUT_MS;
 import static com.android.networkstack.tethering.BpfCoordinator.StatsType;
 import static com.android.networkstack.tethering.BpfCoordinator.StatsType.STATS_PER_IFACE;
 import static com.android.networkstack.tethering.BpfCoordinator.StatsType.STATS_PER_UID;
@@ -78,7 +82,9 @@
 import android.net.ip.ConntrackMonitor;
 import android.net.ip.ConntrackMonitor.ConntrackEventConsumer;
 import android.net.ip.IpServer;
+import android.net.netlink.ConntrackMessage;
 import android.net.netlink.NetlinkConstants;
+import android.net.netlink.NetlinkSocket;
 import android.net.util.InterfaceParams;
 import android.net.util.SharedLog;
 import android.os.Build;
@@ -221,6 +227,7 @@
     // it has to access the non-static function of BPF coordinator.
     private BpfConntrackEventConsumer mConsumer;
 
+    private long mElapsedRealtimeNanos = 0;
     private final ArgumentCaptor<ArrayList> mStringArrayCaptor =
             ArgumentCaptor.forClass(ArrayList.class);
     private final TestLooper mTestLooper = new TestLooper();
@@ -260,6 +267,10 @@
                         return mConntrackMonitor;
                     }
 
+                    public long elapsedRealtimeNanos() {
+                        return mElapsedRealtimeNanos;
+                    }
+
                     @Nullable
                     public BpfMap<Tether4Key, Tether4Value> getBpfDownstream4Map() {
                         return mBpfDownstream4Map;
@@ -1344,6 +1355,11 @@
     }
 
     @NonNull
+    private Tether4Key makeDownstream4Key() {
+        return makeDownstream4Key(IPPROTO_TCP);
+    }
+
+    @NonNull
     private ConntrackEvent makeTestConntrackEvent(short msgType, int proto) {
         if (msgType != IPCTNL_MSG_CT_NEW && msgType != IPCTNL_MSG_CT_DELETE) {
             fail("Not support message type " + msgType);
@@ -1504,4 +1520,104 @@
         mConsumer.accept(makeTestConntrackEvent(IPCTNL_MSG_CT_NEW, IPPROTO_UDP));
         verify(mBpfDevMap, never()).updateEntry(any(), any());
     }
+
+    private void setElapsedRealtimeNanos(long nanoSec) {
+        mElapsedRealtimeNanos = nanoSec;
+    }
+
+    private void checkRefreshConntrackTimeout(final TestBpfMap<Tether4Key, Tether4Value> bpfMap,
+            final Tether4Key tcpKey, final Tether4Value tcpValue, final Tether4Key udpKey,
+            final Tether4Value udpValue) throws Exception {
+        // Both system elapsed time since boot and the rule last used time are used to measure
+        // the rule expiration. In this test, all test rules are fixed the last used time to 0.
+        // Set the different testing elapsed time to make the rule to be valid or expired.
+        //
+        // Timeline:
+        // 0                                       60 (seconds)
+        // +---+---+---+---+--...--+---+---+---+---+---+- ..
+        // |      POLLING_CONNTRACK_TIMEOUT_MS     |
+        // +---+---+---+---+--...--+---+---+---+---+---+- ..
+        // |<-          valid diff           ->|
+        // |<-          expired diff                 ->|
+        // ^                                   ^       ^
+        // last used time      elapsed time (valid)    elapsed time (expired)
+        final long validTime = (POLLING_CONNTRACK_TIMEOUT_MS - 1) * 1_000_000L;
+        final long expiredTime = (POLLING_CONNTRACK_TIMEOUT_MS + 1) * 1_000_000L;
+
+        // Static mocking for NetlinkSocket.
+        MockitoSession mockSession = ExtendedMockito.mockitoSession()
+                .mockStatic(NetlinkSocket.class)
+                .startMocking();
+        try {
+            final BpfCoordinator coordinator = makeBpfCoordinator();
+            coordinator.startPolling();
+            bpfMap.insertEntry(tcpKey, tcpValue);
+            bpfMap.insertEntry(udpKey, udpValue);
+
+            // [1] Don't refresh contrack timeout.
+            setElapsedRealtimeNanos(expiredTime);
+            mTestLooper.moveTimeForward(POLLING_CONNTRACK_TIMEOUT_MS);
+            waitForIdle();
+            ExtendedMockito.verifyNoMoreInteractions(staticMockMarker(NetlinkSocket.class));
+            ExtendedMockito.clearInvocations(staticMockMarker(NetlinkSocket.class));
+
+            // [2] Refresh contrack timeout.
+            setElapsedRealtimeNanos(validTime);
+            mTestLooper.moveTimeForward(POLLING_CONNTRACK_TIMEOUT_MS);
+            waitForIdle();
+            final byte[] expectedNetlinkTcp = ConntrackMessage.newIPv4TimeoutUpdateRequest(
+                    IPPROTO_TCP, PRIVATE_ADDR, (int) PRIVATE_PORT, REMOTE_ADDR,
+                    (int) REMOTE_PORT, NF_CONNTRACK_TCP_TIMEOUT_ESTABLISHED);
+            final byte[] expectedNetlinkUdp = ConntrackMessage.newIPv4TimeoutUpdateRequest(
+                    IPPROTO_UDP, PRIVATE_ADDR, (int) PRIVATE_PORT, REMOTE_ADDR,
+                    (int) REMOTE_PORT, NF_CONNTRACK_UDP_TIMEOUT_STREAM);
+            ExtendedMockito.verify(() -> NetlinkSocket.sendOneShotKernelMessage(
+                    eq(NETLINK_NETFILTER), eq(expectedNetlinkTcp)));
+            ExtendedMockito.verify(() -> NetlinkSocket.sendOneShotKernelMessage(
+                    eq(NETLINK_NETFILTER), eq(expectedNetlinkUdp)));
+            ExtendedMockito.verifyNoMoreInteractions(staticMockMarker(NetlinkSocket.class));
+            ExtendedMockito.clearInvocations(staticMockMarker(NetlinkSocket.class));
+
+            // [3] Don't refresh contrack timeout if polling stopped.
+            coordinator.stopPolling();
+            mTestLooper.moveTimeForward(POLLING_CONNTRACK_TIMEOUT_MS);
+            waitForIdle();
+            ExtendedMockito.verifyNoMoreInteractions(staticMockMarker(NetlinkSocket.class));
+            ExtendedMockito.clearInvocations(staticMockMarker(NetlinkSocket.class));
+        } finally {
+            mockSession.finishMocking();
+        }
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testRefreshConntrackTimeout_Upstream4Map() throws Exception {
+        // TODO: Replace the dependencies BPF map with a non-mocked TestBpfMap object.
+        final TestBpfMap<Tether4Key, Tether4Value> bpfUpstream4Map =
+                new TestBpfMap<>(Tether4Key.class, Tether4Value.class);
+        doReturn(bpfUpstream4Map).when(mDeps).getBpfUpstream4Map();
+
+        final Tether4Key tcpKey = makeUpstream4Key(IPPROTO_TCP);
+        final Tether4Key udpKey = makeUpstream4Key(IPPROTO_UDP);
+        final Tether4Value tcpValue = makeUpstream4Value();
+        final Tether4Value udpValue = makeUpstream4Value();
+
+        checkRefreshConntrackTimeout(bpfUpstream4Map, tcpKey, tcpValue, udpKey, udpValue);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testRefreshConntrackTimeout_Downstream4Map() throws Exception {
+        // TODO: Replace the dependencies BPF map with a non-mocked TestBpfMap object.
+        final TestBpfMap<Tether4Key, Tether4Value> bpfDownstream4Map =
+                new TestBpfMap<>(Tether4Key.class, Tether4Value.class);
+        doReturn(bpfDownstream4Map).when(mDeps).getBpfDownstream4Map();
+
+        final Tether4Key tcpKey = makeDownstream4Key(IPPROTO_TCP);
+        final Tether4Key udpKey = makeDownstream4Key(IPPROTO_UDP);
+        final Tether4Value tcpValue = makeDownstream4Value();
+        final Tether4Value udpValue = makeDownstream4Value();
+
+        checkRefreshConntrackTimeout(bpfDownstream4Map, tcpKey, tcpValue, udpKey, udpValue);
+    }
 }