Merge "libpdx_uds: Always create channel sockets in the server process" into oc-dev am: 584bc3cebf
am: 063a517cd3

Change-Id: I830c85c2cfb4d1a765d82d2d768c843f7011a207
diff --git a/libs/vr/libpdx_uds/client_channel_factory.cpp b/libs/vr/libpdx_uds/client_channel_factory.cpp
index 850c6d3..433f459 100644
--- a/libs/vr/libpdx_uds/client_channel_factory.cpp
+++ b/libs/vr/libpdx_uds/client_channel_factory.cpp
@@ -60,7 +60,7 @@
 
   bool connected = socket_.IsValid();
   if (!connected) {
-    socket_.Reset(socket(AF_UNIX, SOCK_STREAM, 0));
+    socket_.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
     LOG_ALWAYS_FATAL_IF(
         endpoint_path_.empty(),
         "ClientChannelFactory::Connect: unspecified socket path");
@@ -123,6 +123,15 @@
       connected = true;
       ALOGD("ClientChannelFactory: Connected successfully to %s...",
             remote.sun_path);
+      ChannelConnectionInfo<LocalHandle> connection_info;
+      status = ReceiveData(socket_.Borrow(), &connection_info);
+      if (!status)
+        return status.error_status();
+      socket_ = std::move(connection_info.channel_fd);
+      if (!socket_) {
+        ALOGE("ClientChannelFactory::Connect: Failed to obtain channel socket");
+        return ErrorStatus(EIO);
+      }
     }
     if (use_timeout)
       now = steady_clock::now();
@@ -132,11 +141,11 @@
   InitRequest(&request, opcodes::CHANNEL_OPEN, 0, 0, false);
   status = SendData(socket_.Borrow(), request);
   if (!status)
-    return ErrorStatus(status.error());
+    return status.error_status();
   ResponseHeader<LocalHandle> response;
   status = ReceiveData(socket_.Borrow(), &response);
   if (!status)
-    return ErrorStatus(status.error());
+    return status.error_status();
   int ref = response.ret_code;
   if (ref < 0 || static_cast<size_t>(ref) > response.file_descriptors.size())
     return ErrorStatus(EIO);
diff --git a/libs/vr/libpdx_uds/private/uds/ipc_helper.h b/libs/vr/libpdx_uds/private/uds/ipc_helper.h
index 5b7e5ff..bde16d3 100644
--- a/libs/vr/libpdx_uds/private/uds/ipc_helper.h
+++ b/libs/vr/libpdx_uds/private/uds/ipc_helper.h
@@ -116,6 +116,15 @@
 };
 
 template <typename FileHandleType>
+class ChannelConnectionInfo {  // Payload that carries a channel socket handle.
+ public:
+  FileHandleType channel_fd;  // Channel socket the service sends to the client.
+
+ private:
+  PDX_SERIALIZABLE_MEMBERS(ChannelConnectionInfo, channel_fd);
+};
+
+template <typename FileHandleType>
 class RequestHeader {
  public:
   int32_t op{0};
diff --git a/libs/vr/libpdx_uds/private/uds/service_endpoint.h b/libs/vr/libpdx_uds/private/uds/service_endpoint.h
index eb87827..368891c 100644
--- a/libs/vr/libpdx_uds/private/uds/service_endpoint.h
+++ b/libs/vr/libpdx_uds/private/uds/service_endpoint.h
@@ -142,6 +142,8 @@
   BorrowedHandle GetChannelSocketFd(int32_t channel_id);
   BorrowedHandle GetChannelEventFd(int32_t channel_id);
   int32_t GetChannelId(const BorrowedHandle& channel_fd);
+  Status<void> CreateChannelSocketPair(LocalHandle* local_socket,
+                                       LocalHandle* remote_socket);
 
   std::string endpoint_path_;
   bool is_blocking_;
diff --git a/libs/vr/libpdx_uds/service_endpoint.cpp b/libs/vr/libpdx_uds/service_endpoint.cpp
index 6c92259..d96eeff 100644
--- a/libs/vr/libpdx_uds/service_endpoint.cpp
+++ b/libs/vr/libpdx_uds/service_endpoint.cpp
@@ -214,30 +214,42 @@
 
   sockaddr_un remote;
   socklen_t addrlen = sizeof(remote);
-  LocalHandle channel_fd{accept4(socket_fd_.Get(),
-                                 reinterpret_cast<sockaddr*>(&remote), &addrlen,
-                                 SOCK_CLOEXEC)};
-  if (!channel_fd) {
+  LocalHandle connection_fd{accept4(socket_fd_.Get(),
+                                    reinterpret_cast<sockaddr*>(&remote),
+                                    &addrlen, SOCK_CLOEXEC)};
+  if (!connection_fd) {
     ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
           strerror(errno));
     return ErrorStatus(errno);
   }
 
-  int optval = 1;
-  if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
-                 sizeof(optval)) == -1) {
-    ALOGE(
-        "Endpoint::AcceptConnection: Failed to enable the receiving of the "
-        "credentials for channel %d: %s",
-        channel_fd.Get(), strerror(errno));
-    return ErrorStatus(errno);
+  LocalHandle local_socket;
+  LocalHandle remote_socket;
+  auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
+  if (!status)
+    return status;
+
+  // Borrow the local channel handle before we move it into OnNewChannel().
+  BorrowedHandle channel_handle = local_socket.Borrow();
+  status = OnNewChannel(std::move(local_socket));
+  if (!status)
+    return status;
+
+  // Send the channel socket fd to the client.
+  ChannelConnectionInfo<LocalHandle> connection_info;
+  connection_info.channel_fd = std::move(remote_socket);
+  status = SendData(connection_fd.Borrow(), connection_info);
+
+  if (status) {
+    // Get the CHANNEL_OPEN message from client over the channel socket.
+    status = ReceiveMessageForChannel(channel_handle, message);
+  } else {
+    CloseChannel(GetChannelId(channel_handle));
   }
 
-  // Borrow the channel handle before we pass (move) it into OnNewChannel().
-  BorrowedHandle borrowed_channel_handle = channel_fd.Borrow();
-  auto status = OnNewChannel(std::move(channel_fd));
-  if (status)
-    status = ReceiveMessageForChannel(borrowed_channel_handle, message);
+  // Don't need the connection socket anymore. Further communication should
+  // happen over the channel socket.
+  shutdown(connection_fd.Get(), SHUT_WR);
   return status;
 }
 
@@ -349,29 +361,41 @@
   return ErrorStatus{EINVAL};
 }
 
+// Creates a connected AF_UNIX stream socket pair for a channel and enables
+// SO_PASSCRED on |local_socket|. On failure the returned status holds errno.
+Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
+                                               LocalHandle* remote_socket) {
+  Status<void> status;
+  int channel_pair[2] = {};
+  if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
+    status.SetError(errno);  // Save errno before ALOGE can clobber it.
+    ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
+          strerror(status.error()));
+    return status;
+  }
+  local_socket->Reset(channel_pair[0]);
+  remote_socket->Reset(channel_pair[1]);
+  int optval = 1;
+  if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
+                 sizeof(optval)) == -1) {
+    status.SetError(errno);  // Save errno before ALOGE can clobber it.
+    ALOGE(
+        "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
+        "the credentials for channel %d: %s",
+        local_socket->Get(), strerror(status.error()));
+  }
+  return status;
+}
+
 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
                                                   int /*flags*/,
                                                   Channel* channel,
                                                   int* channel_id) {
-  int channel_pair[2] = {};
-  if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
-    ALOGE("Endpoint::PushChannel: Failed to create a socket pair: %s",
-          strerror(errno));
-    return ErrorStatus(errno);
-  }
-
-  LocalHandle local_socket{channel_pair[0]};
-  LocalHandle remote_socket{channel_pair[1]};
-
-  int optval = 1;
-  if (setsockopt(local_socket.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
-                 sizeof(optval)) == -1) {
-    ALOGE(
-        "Endpoint::PushChannel: Failed to enable the receiving of the "
-        "credentials for channel %d: %s",
-        local_socket.Get(), strerror(errno));
-    return ErrorStatus(errno);
-  }
+  LocalHandle local_socket;
+  LocalHandle remote_socket;
+  auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
+  if (!status)
+    return status.error_status();
 
   std::lock_guard<std::mutex> autolock(channel_mutex_);
   auto channel_data = OnNewChannelLocked(std::move(local_socket), channel);