Merge "Add ReadOrAgain and WriteOrAgain methods to FuseMessage."
diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc
index 8fb2dbc..13cfc88 100644
--- a/libappfuse/FuseBuffer.cc
+++ b/libappfuse/FuseBuffer.cc
@@ -23,77 +23,132 @@
#include <algorithm>
#include <type_traits>
+#include <sys/socket.h>
+
#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
namespace android {
namespace fuse {
-
-static_assert(
- std::is_standard_layout<FuseBuffer>::value,
- "FuseBuffer must be standard layout union.");
+namespace {  // File-local helpers shared by the blocking and *OrAgain I/O methods.
template <typename T>
-bool FuseMessage<T>::CheckHeaderLength(const char* name) const {
- const auto& header = static_cast<const T*>(this)->header;
- if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
+bool CheckHeaderLength(const FuseMessage<T>* self, const char* name) {  // True iff sizeof(header) <= header.len <= sizeof(T); 'name' only labels the log.
+ const auto& header = static_cast<const T*>(self)->header;
+ if (header.len >= sizeof(header) && header.len <= sizeof(T)) {
+ return true;
+ } else {
+ LOG(ERROR) << "Invalid header length is found in " << name << ": " << header.len;
+ return false;
+ }
+}
+
+template <typename T>
+ResultOrAgain ReadInternal(FuseMessage<T>* self, int fd, int sockflag) {  // One read() (sockflag == 0) or recv(fd, ..., sockflag) call per message.
+ char* const buf = reinterpret_cast<char*>(self);
+ const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(recv(fd, buf, sizeof(T), sockflag))
+ : TEMP_FAILURE_RETRY(read(fd, buf, sizeof(T)));
+
+ switch (result) {
+ case 0:
+ // Expected EOF; reported as kFailure without logging.
+ return ResultOrAgain::kFailure;
+ case -1:
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {  // POSIX permits these to be distinct values.
+ return ResultOrAgain::kAgain;
+ }
+ PLOG(ERROR) << "Failed to read a FUSE message";
+ return ResultOrAgain::kFailure;
+ }
+
+ const auto& header = static_cast<const T*>(self)->header;
+ if (result < static_cast<ssize_t>(sizeof(header))) {
+ LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << sizeof(header);
+ return ResultOrAgain::kFailure;
+ }
+
+ if (!CheckHeaderLength<T>(self, "Read")) {
+ return ResultOrAgain::kFailure;
+ }
+
+ if (static_cast<uint32_t>(result) != header.len) {  // The single call must return the whole message; short reads are not resumed.
+ LOG(ERROR) << "Read bytes " << result << " are different from header.len " << header.len;
+ return ResultOrAgain::kFailure;
+ }
+
+ return ResultOrAgain::kSuccess;
+}
+
+template <typename T>
+ResultOrAgain WriteInternal(const FuseMessage<T>* self, int fd, int sockflag) {  // Sends header.len bytes in one write() (sockflag == 0) or send() call.
+ if (!CheckHeaderLength<T>(self, "Write")) {
+ return ResultOrAgain::kFailure;
+ }
+
+ const char* const buf = reinterpret_cast<const char*>(self);
+ const auto& header = static_cast<const T*>(self)->header;
+ const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(send(fd, buf, header.len, sockflag))
+ : TEMP_FAILURE_RETRY(write(fd, buf, header.len));
+
+ if (result == -1) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {  // POSIX permits these to be distinct values.
+ return ResultOrAgain::kAgain;
+ }
+ PLOG(ERROR) << "Failed to write a FUSE message";
+ return ResultOrAgain::kFailure;
+ }
+
+ CHECK_EQ(static_cast<uint32_t>(result), header.len);  // Short writes are fatal; assumes a message-atomic fd (e.g. SOCK_SEQPACKET).
+ return ResultOrAgain::kSuccess;
+}
+}  // namespace
+
+static_assert(std::is_standard_layout<FuseBuffer>::value,
+ "FuseBuffer must be standard layout union.");  // The I/O helpers reinterpret_cast messages to raw char buffers.
+
+bool SetupMessageSockets(base::unique_fd (*result)[2]) {  // Creates a connected AF_UNIX SOCK_SEQPACKET pair; *result is assigned only on success.
+ base::unique_fd fds[2];
+ {
+ int raw_fds[2];
+ if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, raw_fds) == -1) {  // NOTE(review): no SOCK_CLOEXEC, so both fds stay open across exec() — confirm intended.
+ PLOG(ERROR) << "Failed to create sockets for proxy";
+ return false;
+ }
+ fds[0].reset(raw_fds[0]);  // unique_fd now owns both ends, so any early return below closes them.
+ fds[1].reset(raw_fds[1]);
+ }
+
+ constexpr int kMaxMessageSize = sizeof(FuseBuffer);  // FuseBuffer is a union, so its size bounds any single message it holds.
+ if (setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0 ||  // NOTE(review): Linux doubles the requested SO_SNDBUF (socket(7)).
+ setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0) {
+ PLOG(ERROR) << "Failed to update buffer size for socket";
+ return false;
+ }
+
+ (*result)[0] = std::move(fds[0]);  // Publish to the caller only after every step above succeeded.
+ (*result)[1] = std::move(fds[1]);
return true;
- } else {
- LOG(ERROR) << "Invalid header length is found in " << name << ": " <<
- header.len;
- return false;
- }
}
template <typename T>
bool FuseMessage<T>::Read(int fd) {
- char* const buf = reinterpret_cast<char*>(this);
- const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T)));
- if (result < 0) {
- PLOG(ERROR) << "Failed to read a FUSE message";
- return false;
- }
+ return ReadInternal(this, fd, 0) == ResultOrAgain::kSuccess;  // Plain read(); kAgain (e.g. from an O_NONBLOCK fd) is reported as false.
+}
- const auto& header = static_cast<const T*>(this)->header;
- if (result < static_cast<ssize_t>(sizeof(header))) {
- LOG(ERROR) << "Read bytes " << result << " are shorter than header size " <<
- sizeof(header);
- return false;
- }
-
- if (!CheckHeaderLength("Read")) {
- return false;
- }
-
- if (static_cast<uint32_t>(result) > header.len) {
- LOG(ERROR) << "Read bytes " << result << " are longer than header.len " <<
- header.len;
- return false;
- }
-
- if (!base::ReadFully(fd, buf + result, header.len - result)) {
- PLOG(ERROR) << "ReadFully failed";
- return false;
- }
-
- return true;
+template <typename T>
+ResultOrAgain FuseMessage<T>::ReadOrAgain(int fd) {  // Non-blocking recv(MSG_DONTWAIT), so fd must be a socket.
+ return ReadInternal(this, fd, MSG_DONTWAIT);
}
template <typename T>
bool FuseMessage<T>::Write(int fd) const {
- if (!CheckHeaderLength("Write")) {
- return false;
- }
+ return WriteInternal(this, fd, 0) == ResultOrAgain::kSuccess;  // Plain write(); kAgain (e.g. from an O_NONBLOCK fd) is reported as false.
+}
- const char* const buf = reinterpret_cast<const char*>(this);
- const auto& header = static_cast<const T*>(this)->header;
- if (!base::WriteFully(fd, buf, header.len)) {
- PLOG(ERROR) << "WriteFully failed";
- return false;
- }
-
- return true;
+template <typename T>
+ResultOrAgain FuseMessage<T>::WriteOrAgain(int fd) const {  // Non-blocking send(MSG_DONTWAIT), so fd must be a socket.
+ return WriteInternal(this, fd, MSG_DONTWAIT);
}
template class FuseMessage<FuseRequest>;
diff --git a/libappfuse/include/libappfuse/FuseBuffer.h b/libappfuse/include/libappfuse/FuseBuffer.h
index 7abd2fa..fbb05d6 100644
--- a/libappfuse/include/libappfuse/FuseBuffer.h
+++ b/libappfuse/include/libappfuse/FuseBuffer.h
@@ -17,6 +17,7 @@
#ifndef ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
#define ANDROID_LIBAPPFUSE_FUSEBUFFER_H_
+#include <android-base/unique_fd.h>
#include <linux/fuse.h>
namespace android {
@@ -28,12 +29,24 @@
constexpr size_t kFuseMaxRead = 128 * 1024;
constexpr int32_t kFuseSuccess = 0;
+// Sets up a connected SOCK_SEQPACKET socket pair for transferring FuseMessages.
+// Returns false, leaving *sockets untouched, if creation or configuration fails.
+bool SetupMessageSockets(base::unique_fd (*sockets)[2]);
+
+// Result of FuseMessage I/O: kFailure also covers EOF; kAgain means it would block (EAGAIN).
+enum class ResultOrAgain {
+ kSuccess,
+ kFailure,
+ kAgain,
+};
+
template<typename T>
class FuseMessage {
public:
bool Read(int fd);
bool Write(int fd) const;
- private:
- bool CheckHeaderLength(const char* name) const;
+ // Non-blocking variants (MSG_DONTWAIT); fd must be a socket.
+ ResultOrAgain ReadOrAgain(int fd);
+ ResultOrAgain WriteOrAgain(int fd) const;
};
@@ -54,7 +67,7 @@
// for FUSE_READ
fuse_read_in read_in;
// for FUSE_LOOKUP
- char lookup_name[0];
+ char lookup_name[kFuseMaxWrite];
};
void Reset(uint32_t data_length, uint32_t opcode, uint64_t unique);
};
diff --git a/libappfuse/tests/FuseAppLoopTest.cc b/libappfuse/tests/FuseAppLoopTest.cc
index 25906cf..64dd813 100644
--- a/libappfuse/tests/FuseAppLoopTest.cc
+++ b/libappfuse/tests/FuseAppLoopTest.cc
@@ -109,10 +109,7 @@
void SetUp() override {
base::SetMinimumLogSeverity(base::VERBOSE);
- int sockets[2];
- ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, sockets));
- sockets_[0].reset(sockets[0]);
- sockets_[1].reset(sockets[1]);
+ ASSERT_TRUE(SetupMessageSockets(&sockets_));
thread_ = std::thread([this] {
StartFuseAppLoop(sockets_[1].release(), &callback_);
});
diff --git a/libappfuse/tests/FuseBridgeLoopTest.cc b/libappfuse/tests/FuseBridgeLoopTest.cc
index e74d9e7..b4c1efb 100644
--- a/libappfuse/tests/FuseBridgeLoopTest.cc
+++ b/libappfuse/tests/FuseBridgeLoopTest.cc
@@ -50,15 +50,8 @@
void SetUp() override {
base::SetMinimumLogSeverity(base::VERBOSE);
- int dev_sockets[2];
- int proxy_sockets[2];
- ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, dev_sockets));
- ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, proxy_sockets));
- dev_sockets_[0].reset(dev_sockets[0]);
- dev_sockets_[1].reset(dev_sockets[1]);
- proxy_sockets_[0].reset(proxy_sockets[0]);
- proxy_sockets_[1].reset(proxy_sockets[1]);
-
+ ASSERT_TRUE(SetupMessageSockets(&dev_sockets_));
+ ASSERT_TRUE(SetupMessageSockets(&proxy_sockets_));
thread_ = std::thread([this] {
StartFuseBridgeLoop(
dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_);
diff --git a/libappfuse/tests/FuseBufferTest.cc b/libappfuse/tests/FuseBufferTest.cc
index 1a1abd5..ade34ac 100644
--- a/libappfuse/tests/FuseBufferTest.cc
+++ b/libappfuse/tests/FuseBufferTest.cc
@@ -112,30 +112,6 @@
TestWriteInvalidLength(sizeof(fuse_in_header) - 1);
}
-TEST(FuseMessageTest, ShortWriteAndRead) {
- int raw_fds[2];
- ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, raw_fds));
-
- android::base::unique_fd fds[2];
- fds[0].reset(raw_fds[0]);
- fds[1].reset(raw_fds[1]);
-
- const int send_buffer_size = 1024;
- ASSERT_EQ(0, setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &send_buffer_size,
- sizeof(int)));
-
- bool succeed = false;
- const int sender_fd = fds[0].get();
- std::thread thread([sender_fd, &succeed] {
- FuseRequest request;
- request.header.len = 1024 * 4;
- succeed = request.Write(sender_fd);
- });
- thread.detach();
- FuseRequest request;
- ASSERT_TRUE(request.Read(fds[1]));
-}
-
TEST(FuseResponseTest, Reset) {
FuseResponse response;
// Write 1 to the first ten bytes.
@@ -211,5 +187,29 @@
EXPECT_EQ(-ENOSYS, buffer.response.header.error);
}
+TEST(SetupMessageSocketsTest, Stress) {  // Pushes kCount max-length requests through a SetupMessageSockets pair.
+ constexpr int kCount = 1000;
+
+ FuseRequest request;
+ request.header.len = sizeof(FuseRequest);
+
+ base::unique_fd fds[2];
+ ASSERT_TRUE(SetupMessageSockets(&fds));  // Checked before spawning the thread so a failure cannot leave it joinable.
+
+ std::thread thread([&fds] {
+ FuseRequest request;
+ for (int i = 0; i < kCount; ++i) {
+ ASSERT_TRUE(request.Read(fds[1]));
+ usleep(1000);  // Throttle the reader so the writer keeps running into the capped SO_SNDBUF.
+ }
+ });
+
+ for (int i = 0; i < kCount; ++i) {
+ ASSERT_TRUE(request.Write(fds[0]));  // NOTE(review): failing here returns with `thread` still joinable -> std::terminate.
+ }
+
+ thread.join();
+}
+
} // namespace fuse
} // namespace android