Mark sockets on accept().

Change-Id: I5d09be413cf720fbed905f96313b007997ada76c
diff --git a/libnetd_client/NetdClient.cpp b/libnetd_client/NetdClient.cpp
index 1d8501a..f7e8cc2 100644
--- a/libnetd_client/NetdClient.cpp
+++ b/libnetd_client/NetdClient.cpp
@@ -18,12 +18,22 @@
 #include "netd_client/FwmarkCommands.h"
 
 #include <sys/socket.h>
+#include <unistd.h>
+
+#define CLOSE_FD_AND_RESTORE_ERRNO(fd) \
+    do { \
+        int error = errno; \
+        close(fd); \
+        errno = error; \
+    } while (0)
 
 namespace {
 
 typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
+typedef int (*AcceptFunctionType)(int, sockaddr*, socklen_t*);
 
 ConnectFunctionType libcConnect = 0;
+AcceptFunctionType libcAccept = 0;
 
 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
     if (FwmarkClient::shouldSetFwmark(sockfd, addr)) {
@@ -35,6 +45,30 @@
     return libcConnect(sockfd, addr, addrlen);
 }
 
+int netdClientAccept(int sockfd, sockaddr* addr, socklen_t* addrlen) {
+    int acceptedSocket = libcAccept(sockfd, addr, addrlen);
+    if (acceptedSocket == -1) {
+        return -1;
+    }
+    sockaddr socketAddress;
+    if (!addr) {
+        socklen_t socketAddressLen = sizeof(socketAddress);
+        if (getsockname(acceptedSocket, &socketAddress, &socketAddressLen) == -1) {
+            CLOSE_FD_AND_RESTORE_ERRNO(acceptedSocket);
+            return -1;
+        }
+        addr = &socketAddress;
+    }
+    if (FwmarkClient::shouldSetFwmark(acceptedSocket, addr)) {
+        char data[] = {FWMARK_COMMAND_ON_ACCEPT};
+        if (!FwmarkClient().send(data, sizeof(data), acceptedSocket)) {
+            CLOSE_FD_AND_RESTORE_ERRNO(acceptedSocket);
+            return -1;
+        }
+    }
+    return acceptedSocket;
+}
+
 }  // namespace
 
 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
@@ -43,3 +77,10 @@
         *function = netdClientConnect;
     }
 }
+
+extern "C" void netdClientInitAccept(AcceptFunctionType* function) {
+    if (function && *function) {
+        libcAccept = *function;
+        *function = netdClientAccept;
+    }
+}