am 56935417: am f03397fa: Merge changes If5359c26,I5d09be41

* commit '56935417ceac36ea411b1489e08fec64d188bd79':
  Use a function instead of a macro.
  Mark sockets on accept().
diff --git a/libnetd_client/NetdClient.cpp b/libnetd_client/NetdClient.cpp
index 1d8501a..8deea1e 100644
--- a/libnetd_client/NetdClient.cpp
+++ b/libnetd_client/NetdClient.cpp
@@ -18,12 +18,22 @@
 #include "netd_client/FwmarkCommands.h"
 
 #include <sys/socket.h>
+#include <unistd.h>
 
 namespace {
 
+int closeFdAndRestoreErrno(int fd) {  // Error-path cleanup: close fd, keep errno intact, yield -1.
+    const int savedErrno = errno;  // The failure the caller is about to report.
+    close(fd);  // close() may itself overwrite errno, hence the save/restore.
+    errno = savedErrno;
+    return -1;
+}
+
 typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
+typedef int (*AcceptFunctionType)(int, sockaddr*, socklen_t*);  // accept()-shaped; type of libcAccept / netdClientAccept.
 
 ConnectFunctionType libcConnect = 0;
+AcceptFunctionType libcAccept = 0;  // Wrapped accept() (presumably libc's), saved by netdClientInitAccept().
 
 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
     if (FwmarkClient::shouldSetFwmark(sockfd, addr)) {
@@ -35,6 +45,28 @@
     return libcConnect(sockfd, addr, addrlen);
 }
 
+int netdClientAccept(int sockfd, sockaddr* addr, socklen_t* addrlen) {
+    // accept() hook: sends ON_ACCEPT to FwmarkClient when shouldSetFwmark(); on failure closes fd.
+    int acceptedSocket = libcAccept(sockfd, addr, addrlen);
+    if (acceptedSocket == -1) return -1;  // errno left as the wrapped accept() set it.
+    // sockaddr_storage fits any family; a bare sockaddr is too small for e.g. sockaddr_in6.
+    sockaddr_storage addressStorage;
+    if (!addr) {  // Caller didn't request the peer address; the local one has the same family.
+        socklen_t addressLen = sizeof(addressStorage);
+        addr = reinterpret_cast<sockaddr*>(&addressStorage);
+        if (getsockname(acceptedSocket, addr, &addressLen) == -1) {
+            return closeFdAndRestoreErrno(acceptedSocket);
+        }
+    }
+    if (FwmarkClient::shouldSetFwmark(acceptedSocket, addr)) {
+        char data[] = {FWMARK_COMMAND_ON_ACCEPT};
+        if (!FwmarkClient().send(data, sizeof(data), acceptedSocket)) {
+            return closeFdAndRestoreErrno(acceptedSocket);
+        }
+    }
+    return acceptedSocket;
+}
+
 }  // namespace
 
 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
@@ -43,3 +75,10 @@
         *function = netdClientConnect;
     }
 }
+
+extern "C" void netdClientInitAccept(AcceptFunctionType* function) {
+    // Hook accept(): remember the current implementation, then install netdClientAccept in its slot.
+    if (!function || !*function) return;  // No slot, or no implementation to forward to.
+    libcAccept = *function;
+    *function = netdClientAccept;
+}