Merge "adb: add nonblocking fd Connection."
am: 812ba6a469

Change-Id: Icfe07b4c4b7511257751a9abb1c2e58b40bfebbe
diff --git a/adb/Android.bp b/adb/Android.bp
index 53bf7e3..97c9762 100644
--- a/adb/Android.bp
+++ b/adb/Android.bp
@@ -104,6 +104,7 @@
     "socket_spec.cpp",
     "sysdeps/errno.cpp",
     "transport.cpp",
+    "transport_fd.cpp",
     "transport_local.cpp",
     "transport_usb.cpp",
 ]
diff --git a/adb/transport.h b/adb/transport.h
index ae9cc02..cb20615 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -92,6 +92,8 @@
     std::string transport_name_;
     ReadCallback read_callback_;
     ErrorCallback error_callback_;
+    // Creates a Connection that does nonblocking I/O on fd and takes ownership of it.
+    static std::unique_ptr<Connection> FromFd(unique_fd fd);
 };
 
 // Abstraction for a blocking packet transport.
diff --git a/adb/transport_benchmark.cpp b/adb/transport_benchmark.cpp
index da24aa7..022808f 100644
--- a/adb/transport_benchmark.cpp
+++ b/adb/transport_benchmark.cpp
@@ -24,13 +24,19 @@
 #include "sysdeps.h"
 #include "transport.h"
 
-#define ADB_CONNECTION_BENCHMARK(benchmark_name, ...)               \
-    BENCHMARK_TEMPLATE(benchmark_name, FdConnection, ##__VA_ARGS__) \
-        ->Arg(1)                                                    \
-        ->Arg(16384)                                                \
-        ->Arg(MAX_PAYLOAD)                                          \
+#define ADB_CONNECTION_BENCHMARK(benchmark_name, ...)                          \
+    BENCHMARK_TEMPLATE(benchmark_name, FdConnection, ##__VA_ARGS__)            \
+        ->Arg(1) /* FdConnection benchmark. */                                 \
+        ->Arg(16384)                                                           \
+        ->Arg(MAX_PAYLOAD)                                                     \
+        ->UseRealTime(); /* ';' separates the two registrations. */            \
+    BENCHMARK_TEMPLATE(benchmark_name, NonblockingFdConnection, ##__VA_ARGS__) \
+        ->Arg(1) /* NonblockingFdConnection benchmark. */                      \
+        ->Arg(16384)                                                           \
+        ->Arg(MAX_PAYLOAD)                                                     \
         ->UseRealTime()
 
+struct NonblockingFdConnection;  // Defined in transport_fd.cpp.
 template <typename ConnectionType>
 std::unique_ptr<Connection> MakeConnection(unique_fd fd);
 
@@ -40,6 +46,11 @@
     return std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection));
 }
 
+template <>  // Benchmarks the implementation returned by Connection::FromFd.
+std::unique_ptr<Connection> MakeConnection<NonblockingFdConnection>(unique_fd connection_fd) {
+    return Connection::FromFd(std::move(connection_fd));
+}
+
 template <typename ConnectionType>
 void BM_Connection_Unidirectional(benchmark::State& state) {
     int fds[2];
diff --git a/adb/transport_fd.cpp b/adb/transport_fd.cpp
new file mode 100644
index 0000000..85f3c52
--- /dev/null
+++ b/adb/transport_fd.cpp
@@ -0,0 +1,302 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <errno.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <atomic>
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <utility>
+
+#include <android-base/logging.h>
+#include <android-base/stringprintf.h>
+#include <android-base/thread_annotations.h>
+
+#include "adb_unique_fd.h"
+#include "adb_utils.h"
+#include "sysdeps.h"
+#include "sysdeps/memory.h"
+#include "transport.h"
+#include "types.h"
+
+// Creates a nonblocking socket pair. Writing to *write makes *read readable, which
+// NonblockingFdConnection uses to interrupt its poll() loop from other threads.
+static void CreateWakeFds(unique_fd* read, unique_fd* write) {
+    // TODO: eventfd on linux?
+    int wake_fds[2];
+    int rc = adb_socketpair(wake_fds);
+    // Check first: wake_fds holds no valid descriptors if adb_socketpair failed.
+    CHECK_EQ(0, rc);
+    set_file_block_mode(wake_fds[0], false);
+    set_file_block_mode(wake_fds[1], false);
+    *read = unique_fd(wake_fds[0]);
+    *write = unique_fd(wake_fds[1]);
+}
+
+// A Connection that services one fd from a single thread using nonblocking I/O.
+//
+// The thread started by Start() poll()s fd_, reassembles incoming amessage
+// headers and payloads into apackets for read_callback_, and flushes output that
+// Write() could not send without blocking. Writing to wake_fd_write_ interrupts
+// that poll() so it can pick up a new event set or notice Stop(). Stop() joins
+// the thread, so it must not be called from read_callback_ or error_callback_,
+// which run on that thread.
+struct NonblockingFdConnection : public Connection {
+    explicit NonblockingFdConnection(unique_fd fd) : started_(false), fd_(std::move(fd)) {
+        set_file_block_mode(fd_.get(), false);
+        CreateWakeFds(&wake_fd_read_, &wake_fd_write_);
+    }
+
+    void SetRunning(bool value) {
+        std::lock_guard<std::mutex> lock(run_mutex_);
+        running_ = value;
+    }
+
+    bool IsRunning() {
+        std::lock_guard<std::mutex> lock(run_mutex_);
+        return running_;
+    }
+
+    // Body of the connection thread. Loops until Stop() clears running_; returns early
+    // with *error set if poll, read, or write fails or the peer hangs up.
+    void Run(std::string* error) {
+        while (IsRunning()) {
+            adb_pollfd pfds[2] = {
+                {.fd = fd_.get(), .events = POLLIN},
+                {.fd = wake_fd_read_.get(), .events = POLLIN},
+            };
+
+            {
+                std::lock_guard<std::mutex> lock(this->write_mutex_);
+                if (!writable_) {
+                    // Write() left data queued; wait until fd_ can accept more.
+                    pfds[0].events |= POLLOUT;
+                }
+            }
+
+            int rc = adb_poll(pfds, 2, -1);
+            if (rc == -1) {
+                *error = android::base::StringPrintf("poll failed: %s", strerror(errno));
+                return;
+            } else if (rc == 0) {
+                LOG(FATAL) << "poll timed out with an infinite timeout?";
+            }
+
+            if (pfds[0].revents) {
+                // Hangup or error with nothing left to read. poll() would keep reporting
+                // it, so bail out instead of spinning.
+                if ((pfds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) &&
+                    !(pfds[0].revents & POLLIN)) {
+                    *error = "connection hung up";
+                    return;
+                }
+
+                if (pfds[0].revents & POLLOUT) {
+                    std::lock_guard<std::mutex> lock(this->write_mutex_);
+                    // While !writable_, Write() only queues, so this thread is the one
+                    // that drains write_buffer_ (DispatchWrites updates writable_).
+                    if (!writable_ && DispatchWrites() == WriteResult::Error) {
+                        *error = "write failed";
+                        return;
+                    }
+                }
+
+                if (pfds[0].revents & POLLIN) {
+                    // TODO: Should we be getting blocks from a free list?
+                    auto block = std::make_unique<IOVector::block_type>(MAX_PAYLOAD);
+                    rc = adb_read(fd_.get(), &(*block)[0], block->size());
+                    if (rc == -1) {
+                        // Spurious wakeup: nothing to read yet.
+                        if (errno != EAGAIN && errno != EWOULDBLOCK) {
+                            *error = std::string("read failed: ") + strerror(errno);
+                            return;
+                        }
+                    } else if (rc == 0) {
+                        *error = "read failed: EOF";
+                        return;
+                    } else {
+                        block->resize(rc);
+                        read_buffer_.append(std::move(block));
+                        DispatchReads();
+                    }
+                }
+            }
+
+            if (pfds[1].revents) {
+                uint64_t buf;
+                rc = adb_read(wake_fd_read_.get(), &buf, sizeof(buf));
+                CHECK_EQ(static_cast<int>(sizeof(buf)), rc);
+
+                // We were woken up either to add POLLOUT to our events, or to exit.
+                // Do nothing.
+            }
+        }
+    }
+
+    // Hands every complete packet in read_buffer_ to read_callback_. One read can
+    // carry several packets, so keep going until only a partial header or payload
+    // remains; a parsed header is kept in read_header_ until its payload arrives.
+    void DispatchReads() {
+        while (true) {
+            if (!read_header_) {
+                if (read_buffer_.size() < sizeof(amessage)) {
+                    return;
+                }
+                auto header_buf = read_buffer_.take_front(sizeof(amessage)).coalesce();
+                CHECK_EQ(sizeof(amessage), header_buf.size());
+                read_header_ = std::make_unique<amessage>();
+                memcpy(read_header_.get(), header_buf.data(), sizeof(amessage));
+            }
+
+            if (read_buffer_.size() < read_header_->data_length) {
+                return;
+            }
+
+            auto data_chain = read_buffer_.take_front(read_header_->data_length);
+
+            // TODO: Make apacket carry around a IOVector instead of coalescing.
+            auto payload = data_chain.coalesce<apacket::payload_type>();
+            auto packet = std::make_unique<apacket>();
+            packet->msg = *read_header_;
+            packet->payload = std::move(payload);
+            read_header_ = nullptr;
+            read_callback_(this, std::move(packet));
+        }
+    }
+
+    void Start() override final {
+        if (started_.exchange(true)) {
+            LOG(FATAL) << "Connection started multiple times?";
+        }
+
+        // Mark the connection running before the thread exists. If Run() did this itself,
+        // a Stop() issued right after Start() could be overwritten by the new thread,
+        // which would then poll forever while Stop() waits in join().
+        SetRunning(true);
+        thread_ = std::thread([this]() {
+            std::string error = "connection closed";
+            Run(&error);
+            this->error_callback_(this, error);
+        });
+    }
+
+    void Stop() override final {
+        SetRunning(false);
+        WakeThread();
+        // Nothing to join if Start() was never called (or Stop() already ran).
+        if (thread_.joinable()) {
+            thread_.join();
+        }
+    }
+
+    // Interrupts the adb_poll() in Run() by making wake_fd_read_ readable.
+    void WakeThread() {
+        uint64_t buf = 0;
+        if (TEMP_FAILURE_RETRY(adb_write(wake_fd_write_.get(), &buf, sizeof(buf))) != sizeof(buf)) {
+            LOG(FATAL) << "failed to wake up thread";
+        }
+    }
+
+    enum class WriteResult {
+        Error,
+        Completed,
+        TryAgain,
+    };
+
+    // Writes as much of write_buffer_ as fd_ accepts without blocking. Sets writable_
+    // to true once the buffer is drained, or to false if data remains so that Run()
+    // polls for POLLOUT and retries. On Error, writable_ is left unchanged.
+    WriteResult DispatchWrites() REQUIRES(write_mutex_) {
+        CHECK(!write_buffer_.empty());
+
+        auto iovs = write_buffer_.iovecs();
+        ssize_t rc = adb_writev(fd_.get(), iovs.data(), iovs.size());
+        if (rc == -1) {
+            if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                // fd_ is full; nothing was written.
+                writable_ = false;
+                return WriteResult::TryAgain;
+            }
+            return WriteResult::Error;
+        } else if (rc == 0) {
+            errno = 0;
+            return WriteResult::Error;
+        }
+
+        // TODO: Implement a more efficient drop_front?
+        write_buffer_.take_front(rc);
+        if (write_buffer_.empty()) {
+            writable_ = true;
+            return WriteResult::Completed;
+        }
+
+        // There's data left in the range, which means our write returned early.
+        writable_ = false;
+        return WriteResult::TryAgain;
+    }
+
+    bool Write(std::unique_ptr<apacket> packet) final {
+        std::lock_guard<std::mutex> lock(write_mutex_);
+        const char* header_begin = reinterpret_cast<const char*>(&packet->msg);
+        const char* header_end = header_begin + sizeof(packet->msg);
+        auto header_block = std::make_unique<IOVector::block_type>(header_begin, header_end);
+        write_buffer_.append(std::move(header_block));
+        if (!packet->payload.empty()) {
+            write_buffer_.append(std::make_unique<IOVector::block_type>(std::move(packet->payload)));
+        }
+
+        if (!writable_) {
+            // An earlier write left data queued and Run() is waiting for POLLOUT;
+            // it will flush this packet after that data.
+            return true;
+        }
+
+        WriteResult result = DispatchWrites();
+        if (result == WriteResult::TryAgain) {
+            // writable_ just became false: wake Run() so it adds POLLOUT to its poll set.
+            WakeThread();
+        }
+        return result != WriteResult::Error;
+    }
+
+    std::thread thread_;
+
+    std::atomic<bool> started_;
+    std::mutex run_mutex_;
+    bool running_ GUARDED_BY(run_mutex_) = false;
+
+    std::unique_ptr<amessage> read_header_;
+    IOVector read_buffer_;
+
+    unique_fd fd_;
+    unique_fd wake_fd_read_;
+    unique_fd wake_fd_write_;
+
+    std::mutex write_mutex_;
+    bool writable_ GUARDED_BY(write_mutex_) = true;
+    IOVector write_buffer_ GUARDED_BY(write_mutex_);
+};
+
+// Creates a NonblockingFdConnection that owns fd. Its I/O thread is not started
+// until Start() is called.
+std::unique_ptr<Connection> Connection::FromFd(unique_fd fd) {
+    return std::make_unique<NonblockingFdConnection>(std::move(fd));
+}