Enable FuseBridgeLoop to accept new mount point after starting

The CL turns StartFuseBridgeLoop function into FuseBridgeLoop class, and
adds a method adding new appfuse mount to the loop.

After doing this, one FuseBridgeLoop can process FUSE commands from
multiple AppFuse mounts.

Bug: 34903085
Test: libappfuse_test
Change-Id: I54f11f54bc26c551281b9c32e9bb91f8f043774c
diff --git a/libappfuse/FuseBridgeLoop.cc b/libappfuse/FuseBridgeLoop.cc
index 2386bf8..3f47066 100644
--- a/libappfuse/FuseBridgeLoop.cc
+++ b/libappfuse/FuseBridgeLoop.cc
@@ -16,85 +16,355 @@
 
 #include "libappfuse/FuseBridgeLoop.h"
 
+#include <sys/epoll.h>
+#include <sys/socket.h>
+
+#include <unordered_map>
+
 #include <android-base/logging.h>
 #include <android-base/unique_fd.h>
 
+#include "libappfuse/EpollController.h"
+
 namespace android {
 namespace fuse {
+namespace {
 
-bool StartFuseBridgeLoop(
-    int raw_dev_fd, int raw_proxy_fd, FuseBridgeLoopCallback* callback) {
-  base::unique_fd dev_fd(raw_dev_fd);
-  base::unique_fd proxy_fd(raw_proxy_fd);
-  FuseBuffer buffer;
-  size_t open_count = 0;
+enum class FuseBridgeState { kWaitToReadEither, kWaitToReadProxy, kWaitToWriteProxy, kClosing };
 
-  LOG(DEBUG) << "Start fuse loop.";
-  while (true) {
-    if (!buffer.request.Read(dev_fd)) {
-      return false;
+struct FuseBridgeEntryEvent {
+    FuseBridgeEntry* entry;
+    int events;
+};
+
+void GetObservedEvents(FuseBridgeState state, int* device_events, int* proxy_events) {
+    switch (state) {
+        case FuseBridgeState::kWaitToReadEither:
+            *device_events = EPOLLIN;
+            *proxy_events = EPOLLIN;
+            return;
+        case FuseBridgeState::kWaitToReadProxy:
+            *device_events = 0;
+            *proxy_events = EPOLLIN;
+            return;
+        case FuseBridgeState::kWaitToWriteProxy:
+            *device_events = 0;
+            *proxy_events = EPOLLOUT;
+            return;
+        case FuseBridgeState::kClosing:
+            *device_events = 0;
+            *proxy_events = 0;
+            return;
+    }
+}
+}
+
+class FuseBridgeEntry {
+  public:
+    FuseBridgeEntry(int mount_id, base::unique_fd&& dev_fd, base::unique_fd&& proxy_fd)
+        : mount_id_(mount_id),
+          device_fd_(std::move(dev_fd)),
+          proxy_fd_(std::move(proxy_fd)),
+          state_(FuseBridgeState::kWaitToReadEither),
+          last_state_(FuseBridgeState::kWaitToReadEither),
+          last_device_events_({this, 0}),
+          last_proxy_events_({this, 0}),
+          open_count_(0) {}
+
+    // Transfer bytes depends on availability of FDs and the internal |state_|.
+    void Transfer(FuseBridgeLoopCallback* callback) {
+        constexpr int kUnexpectedEventMask = ~(EPOLLIN | EPOLLOUT);
+        const bool unexpected_event = (last_device_events_.events & kUnexpectedEventMask) ||
+                                      (last_proxy_events_.events & kUnexpectedEventMask);
+        const bool device_read_ready = last_device_events_.events & EPOLLIN;
+        const bool proxy_read_ready = last_proxy_events_.events & EPOLLIN;
+        const bool proxy_write_ready = last_proxy_events_.events & EPOLLOUT;
+
+        last_device_events_.events = 0;
+        last_proxy_events_.events = 0;
+
+        LOG(VERBOSE) << "Transfer device_read_ready=" << device_read_ready
+                     << " proxy_read_ready=" << proxy_read_ready
+                     << " proxy_write_ready=" << proxy_write_ready;
+
+        if (unexpected_event) {
+            LOG(ERROR) << "Invalid epoll event is observed";
+            state_ = FuseBridgeState::kClosing;
+            return;
+        }
+
+        switch (state_) {
+            case FuseBridgeState::kWaitToReadEither:
+                if (proxy_read_ready) {
+                    state_ = ReadFromProxy();
+                } else if (device_read_ready) {
+                    state_ = ReadFromDevice(callback);
+                }
+                return;
+
+            case FuseBridgeState::kWaitToReadProxy:
+                CHECK(proxy_read_ready);
+                state_ = ReadFromProxy();
+                return;
+
+            case FuseBridgeState::kWaitToWriteProxy:
+                CHECK(proxy_write_ready);
+                state_ = WriteToProxy();
+                return;
+
+            case FuseBridgeState::kClosing:
+                return;
+        }
     }
 
-    const uint32_t opcode = buffer.request.header.opcode;
-    LOG(VERBOSE) << "Read a fuse packet, opcode=" << opcode;
-    switch (opcode) {
-      case FUSE_FORGET:
-        // Do not reply to FUSE_FORGET.
-        continue;
+    bool IsClosing() const { return state_ == FuseBridgeState::kClosing; }
 
-      case FUSE_LOOKUP:
-      case FUSE_GETATTR:
-      case FUSE_OPEN:
-      case FUSE_READ:
-      case FUSE_WRITE:
-      case FUSE_RELEASE:
-      case FUSE_FSYNC:
-        if (!buffer.request.Write(proxy_fd)) {
-          LOG(ERROR) << "Failed to write a request to the proxy.";
-          return false;
+    int mount_id() const { return mount_id_; }
+
+  private:
+    friend class BridgeEpollController;
+
+    FuseBridgeState ReadFromProxy() {
+        switch (buffer_.response.ReadOrAgain(proxy_fd_)) {
+            case ResultOrAgain::kSuccess:
+                break;
+            case ResultOrAgain::kFailure:
+                return FuseBridgeState::kClosing;
+            case ResultOrAgain::kAgain:
+                return FuseBridgeState::kWaitToReadProxy;
         }
-        if (!buffer.response.Read(proxy_fd)) {
-          LOG(ERROR) << "Failed to read a response from the proxy.";
-          return false;
+
+        if (!buffer_.response.Write(device_fd_)) {
+            return FuseBridgeState::kClosing;
         }
-        break;
 
-      case FUSE_INIT:
-        buffer.HandleInit();
-        break;
+        auto it = opcode_map_.find(buffer_.response.header.unique);
+        if (it != opcode_map_.end()) {
+            switch (it->second) {
+                case FUSE_OPEN:
+                    if (buffer_.response.header.error == fuse::kFuseSuccess) {
+                        open_count_++;
+                    }
+                    break;
 
-      default:
-        buffer.HandleNotImpl();
-        break;
+                case FUSE_RELEASE:
+                    if (open_count_ > 0) {
+                        open_count_--;
+                    } else {
+                        LOG(WARNING) << "Unexpected FUSE_RELEASE before opening a file.";
+                        break;
+                    }
+                    if (open_count_ == 0) {
+                        return FuseBridgeState::kClosing;
+                    }
+                    break;
+            }
+            opcode_map_.erase(it);
+        }
+
+        return FuseBridgeState::kWaitToReadEither;
     }
 
-    if (!buffer.response.Write(dev_fd)) {
-      LOG(ERROR) << "Failed to write a response to the device.";
-      return false;
+    FuseBridgeState ReadFromDevice(FuseBridgeLoopCallback* callback) {
+        LOG(VERBOSE) << "ReadFromDevice";
+        if (!buffer_.request.Read(device_fd_)) {
+            return FuseBridgeState::kClosing;
+        }
+
+        const uint32_t opcode = buffer_.request.header.opcode;
+        LOG(VERBOSE) << "Read a fuse packet, opcode=" << opcode;
+        switch (opcode) {
+            case FUSE_FORGET:
+                // Do not reply to FUSE_FORGET.
+                return FuseBridgeState::kWaitToReadEither;
+
+            case FUSE_LOOKUP:
+            case FUSE_GETATTR:
+            case FUSE_OPEN:
+            case FUSE_READ:
+            case FUSE_WRITE:
+            case FUSE_RELEASE:
+            case FUSE_FSYNC:
+                if (opcode == FUSE_OPEN || opcode == FUSE_RELEASE) {
+                    opcode_map_.emplace(buffer_.request.header.unique, opcode);
+                }
+                return WriteToProxy();
+
+            case FUSE_INIT:
+                buffer_.HandleInit();
+                break;
+
+            default:
+                buffer_.HandleNotImpl();
+                break;
+        }
+
+        if (!buffer_.response.Write(device_fd_)) {
+            return FuseBridgeState::kClosing;
+        }
+
+        if (opcode == FUSE_INIT) {
+            callback->OnMount(mount_id_);
+        }
+
+        return FuseBridgeState::kWaitToReadEither;
     }
 
-    switch (opcode) {
-      case FUSE_INIT:
-        callback->OnMount();
-        break;
-      case FUSE_OPEN:
-        if (buffer.response.header.error == fuse::kFuseSuccess) {
-          open_count++;
+    FuseBridgeState WriteToProxy() {
+        switch (buffer_.request.WriteOrAgain(proxy_fd_)) {
+            case ResultOrAgain::kSuccess:
+                return FuseBridgeState::kWaitToReadEither;
+            case ResultOrAgain::kFailure:
+                return FuseBridgeState::kClosing;
+            case ResultOrAgain::kAgain:
+                return FuseBridgeState::kWaitToWriteProxy;
         }
-        break;
-      case FUSE_RELEASE:
-        if (open_count != 0) {
-            open_count--;
-        } else {
-            LOG(WARNING) << "Unexpected FUSE_RELEASE before opening a file.";
-            break;
-        }
-        if (open_count == 0) {
-          return true;
-        }
-        break;
     }
-  }
+
+    const int mount_id_;
+    base::unique_fd device_fd_;
+    base::unique_fd proxy_fd_;
+    FuseBuffer buffer_;
+    FuseBridgeState state_;
+    FuseBridgeState last_state_;
+    FuseBridgeEntryEvent last_device_events_;
+    FuseBridgeEntryEvent last_proxy_events_;
+
+    // Remember map between unique and opcode in fuse_in_header so that we can
+    // refer the opcode later.
+    std::unordered_map<uint64_t, uint32_t> opcode_map_;
+
+    int open_count_;
+
+    DISALLOW_COPY_AND_ASSIGN(FuseBridgeEntry);
+};
+
+class BridgeEpollController : private EpollController {
+  public:
+    BridgeEpollController(base::unique_fd&& poll_fd) : EpollController(std::move(poll_fd)) {}
+
+    bool AddBridgePoll(FuseBridgeEntry* bridge) const {
+        return InvokeControl(EPOLL_CTL_ADD, bridge);
+    }
+
+    bool UpdateOrDeleteBridgePoll(FuseBridgeEntry* bridge) const {
+        return InvokeControl(
+            bridge->state_ != FuseBridgeState::kClosing ? EPOLL_CTL_MOD : EPOLL_CTL_DEL, bridge);
+    }
+
+    bool Wait(size_t bridge_count, std::unordered_set<FuseBridgeEntry*>* entries_out) {
+        CHECK(entries_out);
+        const size_t event_count = std::max<size_t>(bridge_count * 2, 1);
+        if (!EpollController::Wait(event_count)) {
+            return false;
+        }
+        entries_out->clear();
+        for (const auto& event : events()) {
+            FuseBridgeEntryEvent* const entry_event =
+                reinterpret_cast<FuseBridgeEntryEvent*>(event.data.ptr);
+            entry_event->events = event.events;
+            entries_out->insert(entry_event->entry);
+        }
+        return true;
+    }
+
+  private:
+    bool InvokeControl(int op, FuseBridgeEntry* bridge) const {
+        LOG(VERBOSE) << "InvokeControl op=" << op << " bridge=" << bridge->mount_id_
+                     << " state=" << static_cast<int>(bridge->state_)
+                     << " last_state=" << static_cast<int>(bridge->last_state_);
+
+        int last_device_events;
+        int last_proxy_events;
+        int device_events;
+        int proxy_events;
+        GetObservedEvents(bridge->last_state_, &last_device_events, &last_proxy_events);
+        GetObservedEvents(bridge->state_, &device_events, &proxy_events);
+        bool result = true;
+        if (op != EPOLL_CTL_MOD || last_device_events != device_events) {
+            result &= EpollController::InvokeControl(op, bridge->device_fd_, device_events,
+                                                     &bridge->last_device_events_);
+        }
+        if (op != EPOLL_CTL_MOD || last_proxy_events != proxy_events) {
+            result &= EpollController::InvokeControl(op, bridge->proxy_fd_, proxy_events,
+                                                     &bridge->last_proxy_events_);
+        }
+        return result;
+    }
+};
+
+FuseBridgeLoop::FuseBridgeLoop() : opened_(true) {
+    base::unique_fd epoll_fd(epoll_create1(/* no flag */ 0));
+    if (epoll_fd.get() == -1) {
+        PLOG(ERROR) << "Failed to open FD for epoll";
+        opened_ = false;
+        return;
+    }
+    epoll_controller_.reset(new BridgeEpollController(std::move(epoll_fd)));
+}
+
+FuseBridgeLoop::~FuseBridgeLoop() { CHECK(bridges_.empty()); }
+
+bool FuseBridgeLoop::AddBridge(int mount_id, base::unique_fd dev_fd, base::unique_fd proxy_fd) {
+    LOG(VERBOSE) << "Adding bridge " << mount_id;
+
+    std::unique_ptr<FuseBridgeEntry> bridge(
+        new FuseBridgeEntry(mount_id, std::move(dev_fd), std::move(proxy_fd)));
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (!opened_) {
+        LOG(ERROR) << "Tried to add a mount to a closed bridge";
+        return false;
+    }
+    if (bridges_.count(mount_id)) {
+        LOG(ERROR) << "Tried to add a mount point that has already been added";
+        return false;
+    }
+    if (!epoll_controller_->AddBridgePoll(bridge.get())) {
+        return false;
+    }
+
+    bridges_.emplace(mount_id, std::move(bridge));
+    return true;
+}
+
+bool FuseBridgeLoop::ProcessEventLocked(const std::unordered_set<FuseBridgeEntry*>& entries,
+                                        FuseBridgeLoopCallback* callback) {
+    for (auto entry : entries) {
+        entry->Transfer(callback);
+        if (!epoll_controller_->UpdateOrDeleteBridgePoll(entry)) {
+            return false;
+        }
+        if (entry->IsClosing()) {
+            const int mount_id = entry->mount_id();
+            callback->OnClosed(mount_id);
+            bridges_.erase(mount_id);
+            if (bridges_.size() == 0) {
+                // All bridges are now closed.
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
+void FuseBridgeLoop::Start(FuseBridgeLoopCallback* callback) {
+    LOG(DEBUG) << "Start fuse bridge loop";
+    std::unordered_set<FuseBridgeEntry*> entries;
+    while (true) {
+        const bool wait_result = epoll_controller_->Wait(bridges_.size(), &entries);
+        LOG(VERBOSE) << "Receive epoll events";
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            if (!(wait_result && ProcessEventLocked(entries, callback))) {
+                for (auto it = bridges_.begin(); it != bridges_.end();) {
+                    callback->OnClosed(it->second->mount_id());
+                    it = bridges_.erase(it);
+                }
+                opened_ = false;
+                return;
+            }
+        }
+    }
 }
 
 }  // namespace fuse
diff --git a/libappfuse/include/libappfuse/EpollController.h b/libappfuse/include/libappfuse/EpollController.h
index 3863aba..622bd2c 100644
--- a/libappfuse/include/libappfuse/EpollController.h
+++ b/libappfuse/include/libappfuse/EpollController.h
@@ -37,8 +37,10 @@
 
     const std::vector<epoll_event>& events() const;
 
-  private:
+  protected:
     bool InvokeControl(int op, int fd, int events, void* data) const;
+
+  private:
     base::unique_fd poll_fd_;
     std::vector<epoll_event> events_;
 
diff --git a/libappfuse/include/libappfuse/FuseBridgeLoop.h b/libappfuse/include/libappfuse/FuseBridgeLoop.h
index 1f71cf2..6bfda98 100644
--- a/libappfuse/include/libappfuse/FuseBridgeLoop.h
+++ b/libappfuse/include/libappfuse/FuseBridgeLoop.h
@@ -17,6 +17,13 @@
 #ifndef ANDROID_LIBAPPFUSE_FUSEBRIDGELOOP_H_
 #define ANDROID_LIBAPPFUSE_FUSEBRIDGELOOP_H_
 
+#include <map>
+#include <mutex>
+#include <queue>
+#include <unordered_set>
+
+#include <android-base/macros.h>
+
 #include "libappfuse/FuseBuffer.h"
 
 namespace android {
@@ -24,12 +31,41 @@
 
 class FuseBridgeLoopCallback {
  public:
-  virtual void OnMount() = 0;
-  virtual ~FuseBridgeLoopCallback() = default;
+   virtual void OnMount(int mount_id) = 0;
+   virtual void OnClosed(int mount_id) = 0;
+   virtual ~FuseBridgeLoopCallback() = default;
 };
 
-bool StartFuseBridgeLoop(
-    int dev_fd, int proxy_fd, FuseBridgeLoopCallback* callback);
+class FuseBridgeEntry;
+class BridgeEpollController;
+
+class FuseBridgeLoop final {
+  public:
+    FuseBridgeLoop();
+    ~FuseBridgeLoop();
+
+    void Start(FuseBridgeLoopCallback* callback);
+
+    // Add bridge to the loop. It's OK to invoke the method from a different
+    // thread from one which invokes |Start|.
+    bool AddBridge(int mount_id, base::unique_fd dev_fd, base::unique_fd proxy_fd);
+
+  private:
+    bool ProcessEventLocked(const std::unordered_set<FuseBridgeEntry*>& entries,
+                            FuseBridgeLoopCallback* callback);
+
+    std::unique_ptr<BridgeEpollController> epoll_controller_;
+
+    // Map between |mount_id| and bridge entry.
+    std::map<int, std::unique_ptr<FuseBridgeEntry>> bridges_;
+
+    // Lock for multi-threading.
+    std::mutex mutex_;
+
+    bool opened_;
+
+    DISALLOW_COPY_AND_ASSIGN(FuseBridgeLoop);
+};
 
 }  // namespace fuse
 }  // namespace android
diff --git a/libappfuse/tests/FuseBridgeLoopTest.cc b/libappfuse/tests/FuseBridgeLoopTest.cc
index b4c1efb..51d6051 100644
--- a/libappfuse/tests/FuseBridgeLoopTest.cc
+++ b/libappfuse/tests/FuseBridgeLoopTest.cc
@@ -32,10 +32,12 @@
 class Callback : public FuseBridgeLoopCallback {
  public:
   bool mounted;
-  Callback() : mounted(false) {}
-  void OnMount() override {
-    mounted = true;
-  }
+  bool closed;
+  Callback() : mounted(false), closed(false) {}
+
+  void OnMount(int /*mount_id*/) override { mounted = true; }
+
+  void OnClosed(int /* mount_id */) override { closed = true; }
 };
 
 class FuseBridgeLoopTest : public ::testing::Test {
@@ -53,8 +55,9 @@
     ASSERT_TRUE(SetupMessageSockets(&dev_sockets_));
     ASSERT_TRUE(SetupMessageSockets(&proxy_sockets_));
     thread_ = std::thread([this] {
-      StartFuseBridgeLoop(
-          dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_);
+        FuseBridgeLoop loop;
+        loop.AddBridge(1, std::move(dev_sockets_[1]), std::move(proxy_sockets_[0]));
+        loop.Start(&callback_);
     });
   }
 
@@ -115,6 +118,7 @@
     if (thread_.joinable()) {
       thread_.join();
     }
+    ASSERT_TRUE(callback_.closed);
   }
 
   void TearDown() override {