Merge "Add ReadOrAgain and WriteOrAgain methods to FuseMessage."
diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc
index 8fb2dbc..13cfc88 100644
--- a/libappfuse/FuseBuffer.cc
+++ b/libappfuse/FuseBuffer.cc
@@ -23,77 +23,159 @@
 #include <algorithm>
 #include <type_traits>
 
+#include <sys/socket.h>
+
 #include <android-base/file.h>
 #include <android-base/logging.h>
 #include <android-base/macros.h>
 
 namespace android {
 namespace fuse {
-
-static_assert(
-    std::is_standard_layout<FuseBuffer>::value,
-    "FuseBuffer must be standard layout union.");
+namespace {
 
+// Returns true if header.len of |self| is at least the size of the header and
+// at most sizeof(T). Otherwise logs an error mentioning |name| and returns false.
 template <typename T>
-bool FuseMessage<T>::CheckHeaderLength(const char* name) const {
-  const auto& header = static_cast<const T*>(this)->header;
-  if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
+bool CheckHeaderLength(const FuseMessage<T>* self, const char* name) {
+    const auto& header = static_cast<const T*>(self)->header;
+    if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
+        return true;
+    } else {
+        LOG(ERROR) << "Invalid header length is found in " << name << ": " << header.len;
+        return false;
+    }
+}
+
+// Reads one message from |fd| into |self|, using recv(2) with |sockflag| when it
+// is non-zero and read(2) otherwise. The message must arrive in a single call and
+// its size must match header.len. Returns kAgain if the call failed with EAGAIN,
+// and kFailure on EOF, on any other error, or on a malformed message.
+template <typename T>
+ResultOrAgain ReadInternal(FuseMessage<T>* self, int fd, int sockflag) {
+    char* const buf = reinterpret_cast<char*>(self);
+    const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(recv(fd, buf, sizeof(T), sockflag))
+                                    : TEMP_FAILURE_RETRY(read(fd, buf, sizeof(T)));
+
+    switch (result) {
+        case 0:
+            // Expected EOF.
+            return ResultOrAgain::kFailure;
+        case -1:
+            if (errno == EAGAIN) {
+                return ResultOrAgain::kAgain;
+            }
+            PLOG(ERROR) << "Failed to read a FUSE message";
+            return ResultOrAgain::kFailure;
+    }
+
+    const auto& header = static_cast<const T*>(self)->header;
+    if (result < static_cast<ssize_t>(sizeof(header))) {
+        LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << sizeof(header);
+        return ResultOrAgain::kFailure;
+    }
+
+    if (!CheckHeaderLength<T>(self, "Read")) {
+        return ResultOrAgain::kFailure;
+    }
+
+    if (static_cast<uint32_t>(result) != header.len) {
+        LOG(ERROR) << "Read bytes " << result << " are different from header.len " << header.len;
+        return ResultOrAgain::kFailure;
+    }
+
+    return ResultOrAgain::kSuccess;
+}
+
+// Writes the first header.len bytes of |self| to |fd| in a single send(2) call
+// with |sockflag| (or write(2) when |sockflag| is 0). Returns kAgain if the call
+// failed with EAGAIN, and kFailure on a bad header, any other error, or a short write.
+template <typename T>
+ResultOrAgain WriteInternal(const FuseMessage<T>* self, int fd, int sockflag) {
+    if (!CheckHeaderLength<T>(self, "Write")) {
+        return ResultOrAgain::kFailure;
+    }
+
+    const char* const buf = reinterpret_cast<const char*>(self);
+    const auto& header = static_cast<const T*>(self)->header;
+    const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(send(fd, buf, header.len, sockflag))
+                                    : TEMP_FAILURE_RETRY(write(fd, buf, header.len));
+
+    if (result == -1) {
+        if (errno == EAGAIN) {
+            return ResultOrAgain::kAgain;
+        }
+        PLOG(ERROR) << "Failed to write a FUSE message";
+        return ResultOrAgain::kFailure;
+    }
+
+    // Write() accepts any fd, and some fds (e.g. pipes or stream sockets) may take
+    // only part of a message; report that as an error instead of crashing the process.
+    if (static_cast<uint32_t>(result) != header.len) {
+        LOG(ERROR) << "Wrote " << result << " bytes but header.len is " << header.len;
+        return ResultOrAgain::kFailure;
+    }
+
+    return ResultOrAgain::kSuccess;
+}
+}  // namespace
+
+static_assert(std::is_standard_layout<FuseBuffer>::value,
+              "FuseBuffer must be standard layout union.");
+
+// Creates a connected AF_UNIX/SOCK_SEQPACKET socket pair for exchanging FuseMessages
+// and sets SO_SNDBUF of both ends to sizeof(FuseBuffer). On success the sockets are
+// moved into |*result| and true is returned; on failure |*result| is left untouched
+// and false is returned.
+bool SetupMessageSockets(base::unique_fd (*result)[2]) {
+    base::unique_fd fds[2];
+    {
+        int raw_fds[2];
+        if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, raw_fds) == -1) {
+            PLOG(ERROR) << "Failed to create sockets for proxy";
+            return false;
+        }
+        fds[0].reset(raw_fds[0]);
+        fds[1].reset(raw_fds[1]);
+    }
+
+    constexpr int kMaxMessageSize = sizeof(FuseBuffer);
+    if (setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0 ||
+        setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0) {
+        PLOG(ERROR) << "Failed to update buffer size for socket";
+        return false;
+    }
+
+    (*result)[0] = std::move(fds[0]);
+    (*result)[1] = std::move(fds[1]);
     return true;
-  } else {
-    LOG(ERROR) << "Invalid header length is found in " << name << ": " <<
-        header.len;
-    return false;
-  }
 }
 
+// Reads one message with read(2). Returns true only on success, so EAGAIN on a
+// non-blocking fd is reported as false as well.
 template <typename T>
 bool FuseMessage<T>::Read(int fd) {
-  char* const buf = reinterpret_cast<char*>(this);
-  const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T)));
-  if (result < 0) {
-    PLOG(ERROR) << "Failed to read a FUSE message";
-    return false;
-  }
+    return ReadInternal(this, fd, 0) == ResultOrAgain::kSuccess;
+}
 
-  const auto& header = static_cast<const T*>(this)->header;
-  if (result < static_cast<ssize_t>(sizeof(header))) {
-    LOG(ERROR) << "Read bytes " << result << " are shorter than header size " <<
-        sizeof(header);
-    return false;
-  }
-
-  if (!CheckHeaderLength("Read")) {
-    return false;
-  }
-
-  if (static_cast<uint32_t>(result) > header.len) {
-    LOG(ERROR) << "Read bytes " << result << " are longer than header.len " <<
-        header.len;
-    return false;
-  }
-
-  if (!base::ReadFully(fd, buf + result, header.len - result)) {
-    PLOG(ERROR) << "ReadFully failed";
-    return false;
-  }
-
-  return true;
+// Non-blocking read with recv(2) and MSG_DONTWAIT. Returns kAgain when no message
+// is available yet; |fd| must be a socket.
+template <typename T>
+ResultOrAgain FuseMessage<T>::ReadOrAgain(int fd) {
+    return ReadInternal(this, fd, MSG_DONTWAIT);
 }
 
+// Writes one message with write(2). Returns true only on success, so EAGAIN on a
+// non-blocking fd is reported as false as well.
 template <typename T>
 bool FuseMessage<T>::Write(int fd) const {
-  if (!CheckHeaderLength("Write")) {
-    return false;
-  }
+    return WriteInternal(this, fd, 0) == ResultOrAgain::kSuccess;
+}
 
-  const char* const buf = reinterpret_cast<const char*>(this);
-  const auto& header = static_cast<const T*>(this)->header;
-  if (!base::WriteFully(fd, buf, header.len)) {
-    PLOG(ERROR) << "WriteFully failed";
-    return false;
-  }
-
-  return true;
+// Non-blocking write with send(2) and MSG_DONTWAIT. Returns kAgain when the send
+// would block; |fd| must be a socket.
+template <typename T>
+ResultOrAgain FuseMessage<T>::WriteOrAgain(int fd) const {
+    return WriteInternal(this, fd, MSG_DONTWAIT);
 }
 
 template class FuseMessage<FuseRequest>;
diff --git a/libappfuse/include/libappfuse/FuseBuffer.h b/libappfuse/include/libappfuse/FuseBuffer.h
index 7abd2fa..fbb05d6 100644
--- a/libappfuse/include/libappfuse/FuseBuffer.h
+++ b/libappfuse/include/libappfuse/FuseBuffer.h
@@ -17,6 +17,7 @@
 #ifndef ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
 #define ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
 
+#include <android-base/unique_fd.h>
 #include <linux/fuse.h>
 
 namespace android {
@@ -28,12 +29,26 @@
 constexpr size_t kFuseMaxRead = 128 * 1024;
 constexpr int32_t kFuseSuccess = 0;
 
+// Creates a connected AF_UNIX/SOCK_SEQPACKET socket pair for transferring
+// FuseMessages and stores it in |*sockets|. Returns false on failure.
+bool SetupMessageSockets(base::unique_fd (*sockets)[2]);
+
+// Outcome of FuseMessage::ReadOrAgain() and FuseMessage::WriteOrAgain().
+enum class ResultOrAgain {
+    kSuccess,
+    kFailure,
+    // The operation would have blocked (EAGAIN); it can be retried later.
+    kAgain,
+};
+
 template<typename T>
 class FuseMessage {
  public:
   bool Read(int fd);
   bool Write(int fd) const;
- private:
-  bool CheckHeaderLength(const char* name) const;
+  // Like Read()/Write(), but use recv()/send() with MSG_DONTWAIT and return kAgain
+  // instead of blocking. |fd| must be a socket.
+  ResultOrAgain ReadOrAgain(int fd);
+  ResultOrAgain WriteOrAgain(int fd) const;
 };
 
@@ -54,7 +69,7 @@
     // for FUSE_READ
     fuse_read_in read_in;
     // for FUSE_LOOKUP
-    char lookup_name[0];
+    char lookup_name[kFuseMaxWrite];
   };
   void Reset(uint32_t data_length, uint32_t opcode, uint64_t unique);
 };
diff --git a/libappfuse/tests/FuseAppLoopTest.cc b/libappfuse/tests/FuseAppLoopTest.cc
index 25906cf..64dd813 100644
--- a/libappfuse/tests/FuseAppLoopTest.cc
+++ b/libappfuse/tests/FuseAppLoopTest.cc
@@ -109,10 +109,7 @@
 
   void SetUp() override {
     base::SetMinimumLogSeverity(base::VERBOSE);
-    int sockets[2];
-    ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, sockets));
-    sockets_[0].reset(sockets[0]);
-    sockets_[1].reset(sockets[1]);
+    ASSERT_TRUE(SetupMessageSockets(&sockets_));
     thread_ = std::thread([this] {
       StartFuseAppLoop(sockets_[1].release(), &callback_);
     });
diff --git a/libappfuse/tests/FuseBridgeLoopTest.cc b/libappfuse/tests/FuseBridgeLoopTest.cc
index e74d9e7..b4c1efb 100644
--- a/libappfuse/tests/FuseBridgeLoopTest.cc
+++ b/libappfuse/tests/FuseBridgeLoopTest.cc
@@ -50,15 +50,8 @@
 
   void SetUp() override {
     base::SetMinimumLogSeverity(base::VERBOSE);
-    int dev_sockets[2];
-    int proxy_sockets[2];
-    ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, dev_sockets));
-    ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, proxy_sockets));
-    dev_sockets_[0].reset(dev_sockets[0]);
-    dev_sockets_[1].reset(dev_sockets[1]);
-    proxy_sockets_[0].reset(proxy_sockets[0]);
-    proxy_sockets_[1].reset(proxy_sockets[1]);
-
+    ASSERT_TRUE(SetupMessageSockets(&dev_sockets_));
+    ASSERT_TRUE(SetupMessageSockets(&proxy_sockets_));
     thread_ = std::thread([this] {
       StartFuseBridgeLoop(
           dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_);
diff --git a/libappfuse/tests/FuseBufferTest.cc b/libappfuse/tests/FuseBufferTest.cc
index 1a1abd5..ade34ac 100644
--- a/libappfuse/tests/FuseBufferTest.cc
+++ b/libappfuse/tests/FuseBufferTest.cc
@@ -112,30 +112,6 @@
   TestWriteInvalidLength(sizeof(fuse_in_header) - 1);
 }
 
-TEST(FuseMessageTest, ShortWriteAndRead) {
-  int raw_fds[2];
-  ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, raw_fds));
-
-  android::base::unique_fd fds[2];
-  fds[0].reset(raw_fds[0]);
-  fds[1].reset(raw_fds[1]);
-
-  const int send_buffer_size = 1024;
-  ASSERT_EQ(0, setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &send_buffer_size,
-                          sizeof(int)));
-
-  bool succeed = false;
-  const int sender_fd = fds[0].get();
-  std::thread thread([sender_fd, &succeed] {
-    FuseRequest request;
-    request.header.len = 1024 * 4;
-    succeed = request.Write(sender_fd);
-  });
-  thread.detach();
-  FuseRequest request;
-  ASSERT_TRUE(request.Read(fds[1]));
-}
-
 TEST(FuseResponseTest, Reset) {
   FuseResponse response;
   // Write 1 to the first ten bytes.
@@ -211,5 +187,43 @@
   EXPECT_EQ(-ENOSYS, buffer.response.header.error);
 }
 
+TEST(SetupMessageSocketsTest, Stress) {
+    constexpr int kCount = 1000;
+
+    FuseRequest request;
+    request.header.len = sizeof(FuseRequest);
+
+    base::unique_fd fds[2];
+    ASSERT_TRUE(SetupMessageSockets(&fds));
+
+    // The reader sleeps between messages so that the writer keeps filling the small
+    // send buffer configured by SetupMessageSockets and has to block.
+    std::thread thread([&fds] {
+        FuseRequest received;
+        for (int i = 0; i < kCount; ++i) {
+            if (!received.Read(fds[1])) {
+                ADD_FAILURE() << "Read #" << i << " failed";
+                // Close the read end so a writer blocked on the full send buffer
+                // fails (EPIPE/SIGPIPE) instead of hanging the test.
+                fds[1].reset();
+                return;
+            }
+            usleep(1000);
+        }
+    });
+
+    for (int i = 0; i < kCount; ++i) {
+        if (!request.Write(fds[0])) {
+            ADD_FAILURE() << "Write #" << i << " failed";
+            // Close the write end so the reader sees EOF and exits. ASSERT_* is not
+            // used here because its early return would skip thread.join().
+            fds[0].reset();
+            break;
+        }
+    }
+
+    thread.join();
+}
+
 } // namespace fuse
 } // namespace android