Merge changes I1abd671f,I9ae61465 am: 0220ca7d09
am: 3839ff3bf9

Change-Id: I9500393c2756c35c9eb8edb2ce018cae76716a1e
diff --git a/adb/fdevent.cpp b/adb/fdevent.cpp
index fa3738d..32f9086 100644
--- a/adb/fdevent.cpp
+++ b/adb/fdevent.cpp
@@ -32,6 +32,7 @@
 #include <functional>
 #include <list>
 #include <mutex>
+#include <optional>
 #include <unordered_map>
 #include <utility>
 #include <variant>
@@ -225,14 +226,22 @@
 
 void fdevent_add(fdevent* fde, unsigned events) {
     check_main_thread();
+    CHECK(!(events & FDE_TIMEOUT));  // Timeouts are set via fdevent_set_timeout().
     fdevent_set(fde, (fde->state & FDE_EVENTMASK) | events);
 }
 
 void fdevent_del(fdevent* fde, unsigned events) {
     check_main_thread();
+    CHECK(!(events & FDE_TIMEOUT));  // Timeouts are cleared via fdevent_set_timeout().
     fdevent_set(fde, (fde->state & FDE_EVENTMASK) & ~events);
 }
 
+void fdevent_set_timeout(fdevent* fde, std::optional<std::chrono::milliseconds> timeout) {
+    check_main_thread();
+    fde->timeout = timeout;  // std::nullopt means this fdevent has no deadline.
+    fde->last_active = std::chrono::steady_clock::now();  // Restart the countdown from now.
+}
+
 static std::string dump_pollfds(const std::vector<adb_pollfd>& pollfds) {
     std::string result;
     for (const auto& pollfd : pollfds) {
@@ -248,6 +257,32 @@
     return result;
 }
 
+// Returns the poll() timeout that honors the nearest fdevent deadline (zero if one has
+// already passed), or std::nullopt when no installed fdevent has a timeout set.
+static std::optional<std::chrono::milliseconds> calculate_timeout() {
+    std::optional<std::chrono::milliseconds> result = std::nullopt;
+    auto now = std::chrono::steady_clock::now();
+    check_main_thread();
+
+    for (const auto& [fd, pollnode] : g_poll_node_map) {
+        UNUSED(fd);
+        auto timeout_opt = pollnode.fde->timeout;
+        if (timeout_opt) {
+            auto deadline = pollnode.fde->last_active + *timeout_opt;
+            // Round up: a truncated timeout wakes poll() just before the deadline, and the
+            // loop would then spin on zero-length polls until the deadline finally passes.
+            auto time_left = std::chrono::ceil<std::chrono::milliseconds>(deadline - now);
+            if (time_left < std::chrono::milliseconds::zero()) {
+                time_left = std::chrono::milliseconds::zero();
+            }
+
+            result = result ? std::min(*result, time_left) : time_left;
+        }
+    }
+
+    return result;
+}
+
 static void fdevent_process() {
     std::vector<adb_pollfd> pollfds;
     for (const auto& pair : g_poll_node_map) {
@@ -256,11 +291,22 @@
     CHECK_GT(pollfds.size(), 0u);
     D("poll(), pollfds = %s", dump_pollfds(pollfds).c_str());
 
-    int ret = adb_poll(&pollfds[0], pollfds.size(), -1);
+    auto timeout = calculate_timeout();
+    int timeout_ms;
+    if (!timeout) {
+        timeout_ms = -1;
+    } else {
+        timeout_ms = timeout->count();
+    }
+
+    int ret = adb_poll(&pollfds[0], pollfds.size(), timeout_ms);
     if (ret == -1) {
         PLOG(ERROR) << "poll(), ret = " << ret;
         return;
     }
+
+    auto post_poll = std::chrono::steady_clock::now();
+
     for (const auto& pollfd : pollfds) {
         if (pollfd.revents != 0) {
             D("for fd %d, revents = %x", pollfd.fd, pollfd.revents);
@@ -282,12 +328,24 @@
             events |= FDE_READ | FDE_ERROR;
         }
 #endif
+        auto it = g_poll_node_map.find(pollfd.fd);
+        CHECK(it != g_poll_node_map.end());
+        fdevent* fde = it->second.fde;
+
+        if (events == 0) {
+            // Check for timeout.
+            if (fde->timeout) {
+                auto deadline = fde->last_active + *fde->timeout;
+                if (deadline < post_poll) {
+                    events |= FDE_TIMEOUT;
+                }
+            }
+        }
+
         if (events != 0) {
-            auto it = g_poll_node_map.find(pollfd.fd);
-            CHECK(it != g_poll_node_map.end());
-            fdevent* fde = it->second.fde;
             CHECK_EQ(fde->fd.get(), pollfd.fd);
             fde->events |= events;
+            fde->last_active = post_poll;
             D("%s got events %x", dump_fde(fde).c_str(), events);
             fde->state |= FDE_PENDING;
             g_pending_list.push_back(fde);
diff --git a/adb/fdevent.h b/adb/fdevent.h
index 70e0a96..42dbb9e 100644
--- a/adb/fdevent.h
+++ b/adb/fdevent.h
@@ -18,17 +18,20 @@
 #define __FDEVENT_H
 
 #include <stddef.h>
-#include <stdint.h>  /* for int64_t */
+#include <stdint.h>
 
+#include <chrono>
 #include <functional>
+#include <optional>
 #include <variant>
 
 #include "adb_unique_fd.h"
 
-/* events that may be observed */
-#define FDE_READ              0x0001
-#define FDE_WRITE             0x0002
-#define FDE_ERROR             0x0004
+// Events that may be observed
+#define FDE_READ 0x0001
+#define FDE_WRITE 0x0002
+#define FDE_ERROR 0x0004
+#define FDE_TIMEOUT 0x0008  // No event arrived within the fdevent's timeout.
 
 typedef void (*fd_func)(int fd, unsigned events, void *userdata);
 typedef void (*fd_func2)(struct fdevent* fde, unsigned events, void* userdata);
@@ -41,6 +44,8 @@
 
     uint16_t state = 0;
     uint16_t events = 0;
+    std::optional<std::chrono::milliseconds> timeout;
+    std::chrono::steady_clock::time_point last_active;
 
     std::variant<fd_func, fd_func2> func;
     void* arg = nullptr;
@@ -62,7 +67,11 @@
 void fdevent_add(fdevent *fde, unsigned events);
 void fdevent_del(fdevent *fde, unsigned events);
 
-void fdevent_set_timeout(fdevent *fde, int64_t  timeout_ms);
+// Set a timeout on an fdevent.
+// If no events are triggered before the timeout elapses, an FDE_TIMEOUT will be generated.
+// Note timeouts are not defused automatically; if a timeout is set on an fdevent, it will
+// trigger repeatedly every |timeout| ms.
+void fdevent_set_timeout(fdevent* fde, std::optional<std::chrono::milliseconds> timeout);
 
 // Loop forever, handling events.
 void fdevent_loop();
diff --git a/adb/fdevent_test.cpp b/adb/fdevent_test.cpp
index a9746bb..682f061 100644
--- a/adb/fdevent_test.cpp
+++ b/adb/fdevent_test.cpp
@@ -18,6 +18,7 @@
 
 #include <gtest/gtest.h>
 
+#include <chrono>
 #include <limits>
 #include <memory>
 #include <queue>
@@ -28,6 +29,8 @@
 #include "adb_io.h"
 #include "fdevent_test.h"
 
+using namespace std::chrono_literals;
+
 class FdHandler {
   public:
     FdHandler(int read_fd, int write_fd, bool use_new_callback)
@@ -257,3 +260,100 @@
         ASSERT_EQ(i, vec[i]);
     }
 }
+
+TEST_F(FdeventTest, timeout) {
+    fdevent_reset();
+    PrepareThread();
+
+    enum class TimeoutEvent {
+        read,
+        timeout,
+        done,
+    };
+
+    struct TimeoutTest {
+        std::vector<std::pair<TimeoutEvent, std::chrono::steady_clock::time_point>> events;
+        fdevent* fde;
+    };
+    TimeoutTest test;
+
+    int fds[2];
+    ASSERT_EQ(0, adb_socketpair(fds));
+    static constexpr auto delta = 100ms;  // Timeout period under test.
+    fdevent_run_on_main_thread([&]() {
+        test.fde = fdevent_create(fds[0], [](fdevent* fde, unsigned events, void* arg) {
+            auto test = static_cast<TimeoutTest*>(arg);
+            auto now = std::chrono::steady_clock::now();
+            CHECK(!(events & FDE_READ) != !(events & FDE_TIMEOUT));  // Exactly one must be set.
+            TimeoutEvent event;
+            if ((events & FDE_READ)) {
+                char buf[2];
+                ssize_t rc = adb_read(fde->fd.get(), buf, sizeof(buf));
+                if (rc == 0) {
+                    event = TimeoutEvent::done;  // EOF: the writer closed its end.
+                } else if (rc == 1) {
+                    event = TimeoutEvent::read;
+                } else {
+                    abort();
+                }
+            } else if ((events & FDE_TIMEOUT)) {
+                event = TimeoutEvent::timeout;
+            } else {
+                abort();
+            }
+
+            CHECK_EQ(fde, test->fde);
+            test->events.emplace_back(event, now);
+
+            if (event == TimeoutEvent::done) {
+                fdevent_destroy(fde);
+            }
+        }, &test);
+        fdevent_add(test.fde, FDE_READ);
+        fdevent_set_timeout(test.fde, delta);
+    });
+
+    ASSERT_EQ(1, adb_write(fds[1], "", 1));
+
+    // Timeout should happen here
+    std::this_thread::sleep_for(delta);
+
+    // and another.
+    std::this_thread::sleep_for(delta);
+
+    // No timeout should happen here.
+    std::this_thread::sleep_for(delta / 2);
+    adb_close(fds[1]);
+
+    TerminateThread();
+
+    ASSERT_EQ(4ULL, test.events.size());
+    ASSERT_EQ(TimeoutEvent::read, test.events[0].first);
+    ASSERT_EQ(TimeoutEvent::timeout, test.events[1].first);
+    ASSERT_EQ(TimeoutEvent::timeout, test.events[2].first);
+    ASSERT_EQ(TimeoutEvent::done, test.events[3].first);
+
+    std::vector<int> time_deltas;  // Milliseconds between consecutive recorded events.
+    for (size_t i = 0; i < test.events.size() - 1; ++i) {
+        auto before = test.events[i].second;
+        auto after = test.events[i + 1].second;
+        auto diff = std::chrono::duration_cast<std::chrono::milliseconds>(after - before);
+        time_deltas.push_back(diff.count());
+    }
+
+    std::vector<int> expected = {
+        delta.count(),
+        delta.count(),
+        delta.count() / 2,
+    };
+
+    std::vector<int> diff;
+    ASSERT_EQ(time_deltas.size(), expected.size());
+    for (size_t i = 0; i < time_deltas.size(); ++i) {
+        diff.push_back(std::abs(time_deltas[i] - expected[i]));
+    }
+
+    ASSERT_LT(diff[0], delta.count() * 0.5);
+    ASSERT_LT(diff[1], delta.count() * 0.5);
+    ASSERT_LT(diff[2], delta.count() * 0.5);
+}
diff --git a/adb/socket_test.cpp b/adb/socket_test.cpp
index 7908f82..5e28f76 100644
--- a/adb/socket_test.cpp
+++ b/adb/socket_test.cpp
@@ -221,6 +221,8 @@
     EXPECT_EQ(2u + GetAdditionalLocalSocketCount(), fdevent_installed_count());
     ASSERT_EQ(0, adb_close(socket_fd[0]));
 
+    std::this_thread::sleep_for(2s);
+
     WaitForFdeventLoop();
     ASSERT_EQ(GetAdditionalLocalSocketCount(), fdevent_installed_count());
     TerminateThread();
diff --git a/adb/sockets.cpp b/adb/sockets.cpp
index f7c39f0..420a6d5 100644
--- a/adb/sockets.cpp
+++ b/adb/sockets.cpp
@@ -26,6 +26,7 @@
 #include <unistd.h>
 
 #include <algorithm>
+#include <chrono>
 #include <mutex>
 #include <string>
 #include <vector>
@@ -41,6 +42,8 @@
 #include "transport.h"
 #include "types.h"
 
+using namespace std::chrono_literals;
+
 static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
 static unsigned local_socket_next_id = 1;
 
@@ -238,16 +241,64 @@
     fdevent_add(s->fde, FDE_READ);
 }
 
+struct ClosingSocket {
+    std::chrono::steady_clock::time_point begin;  // When deferred_close() started flushing.
+};
+
+// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
+// socket while we have pending data, a TCP RST should be sent to the
+// other end to notify it that we didn't read all of its data. However,
+// this can result in data that we've successfully written out to be dropped
+// on the other end. To avoid this, instead of immediately closing a
+// socket, call shutdown on it instead, and then read from the file
+// descriptor until we hit EOF or an error before closing.
+static void deferred_close(unique_fd fd) {
+    // Shutdown the socket in the outgoing direction only, so that
+    // we don't have the same problem on the opposite end.
+    adb_shutdown(fd.get(), SHUT_WR);
+    auto callback = [](fdevent* fde, unsigned event, void* arg) {
+        auto socket_info = static_cast<ClosingSocket*>(arg);
+        if (event & FDE_READ) {
+            ssize_t rc;
+            char buf[BUFSIZ];
+            while ((rc = adb_read(fde->fd.get(), buf, sizeof(buf))) > 0) {
+                continue;  // Discard the data; we only read to reach EOF or an error.
+            }
+
+            if (rc == -1 && errno == EAGAIN) {
+                // There's potentially more data to read.
+                auto duration = std::chrono::steady_clock::now() - socket_info->begin;
+                if (duration > 1s) {  // FDE_TIMEOUT never fires while data keeps arriving.
+                    LOG(WARNING) << "timeout expired while flushing socket, closing";
+                } else {
+                    return;  // Keep the fdevent installed and wait for more data.
+                }
+            }
+        } else if (event & FDE_TIMEOUT) {
+            LOG(WARNING) << "timeout expired while flushing socket, closing";
+        }
+
+        // Either there was an error, we hit the end of the socket, or our timeout expired.
+        fdevent_destroy(fde);
+        delete socket_info;  // Allocated with new when the fdevent was created.
+    };
+
+    ClosingSocket* socket_info = new ClosingSocket{
+            .begin = std::chrono::steady_clock::now(),
+    };
+
+    fdevent* fde = fdevent_create(fd.release(), callback, socket_info);
+    fdevent_add(fde, FDE_READ);
+    fdevent_set_timeout(fde, 1s);  // Stop waiting if the peer stays silent for a full second.
+}
+
 // be sure to hold the socket list lock when calling this
 static void local_socket_destroy(asocket* s) {
     int exit_on_close = s->exit_on_close;
 
     D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
 
-    /* IMPORTANT: the remove closes the fd
-    ** that belongs to this socket
-    */
-    fdevent_destroy(s->fde);
+    deferred_close(fdevent_release(s->fde));
 
     remove_socket(s);
     delete s;
diff --git a/adb/sysdeps/chrono.h b/adb/sysdeps/chrono.h
index c73a638..5c5af7c 100644
--- a/adb/sysdeps/chrono.h
+++ b/adb/sysdeps/chrono.h
@@ -18,29 +18,4 @@
 
 #include <chrono>
 
-#if defined(_WIN32)
-// We don't have C++14 on Windows yet.
-// Reimplement std::chrono_literals ourselves until we do.
-
-// Silence the following warning (which gets promoted to an error):
-// error: literal operator suffixes not preceded by ‘_’ are reserved for future standardization
-#pragma GCC system_header
-
-constexpr std::chrono::seconds operator"" s(unsigned long long s) {
-    return std::chrono::seconds(s);
-}
-
-constexpr std::chrono::duration<long double> operator"" s(long double s) {
-    return std::chrono::duration<long double>(s);
-}
-
-constexpr std::chrono::milliseconds operator"" ms(unsigned long long ms) {
-    return std::chrono::milliseconds(ms);
-}
-
-constexpr std::chrono::duration<long double, std::milli> operator"" ms(long double ms) {
-    return std::chrono::duration<long double, std::milli>(ms);
-}
-#else
 using namespace std::chrono_literals;
-#endif
diff --git a/adb/test_device.py b/adb/test_device.py
index 34f8fd9..f95a5b3 100755
--- a/adb/test_device.py
+++ b/adb/test_device.py
@@ -35,6 +35,8 @@
 import time
 import unittest
 
+from datetime import datetime
+
 import adb
 
 def requires_root(func):
@@ -1335,6 +1337,63 @@
             self.device.forward_remove("tcp:{}".format(local_port))
 
 
+class SocketTest(DeviceTest):
+    def test_socket_flush(self):
+        """Test that we handle socket closure properly.
+
+        If we're done writing to a socket, closing before the other end has
+        closed will send a TCP_RST if we have incoming data queued up, which
+        may result in data that we've written being discarded.
+
+        Bug: http://b/74616284
+        """
+        s = socket.create_connection(("localhost", 5037))
+
+        def adb_length_prefixed(string):
+            encoded = string.encode("utf8")
+            result = b"%04x%s" % (len(encoded), encoded)
+            return result
+
+        if "ANDROID_SERIAL" in os.environ:
+            transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
+        else:
+            transport_string = "host:transport-any"
+
+        s.sendall(adb_length_prefixed(transport_string))
+        response = s.recv(4)
+        self.assertEquals(b"OKAY", response)
+
+        shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
+        s.sendall(adb_length_prefixed(shell_string))
+
+        response = s.recv(4)
+        self.assertEquals(b"OKAY", response)
+
+        # Spawn a thread that dumps garbage into the socket until failure.
+        def spam():
+            buf = b"\0" * 16384
+            try:
+                while True:
+                    s.sendall(buf)
+            except Exception as ex:
+                print(ex)
+
+        thread = threading.Thread(target=spam)
+        thread.start()
+
+        time.sleep(1)
+
+        received = b""
+        while True:
+            read = s.recv(512)
+            if len(read) == 0:
+                break
+            received += read
+
+        self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
+        thread.join()
+
+
 if sys.platform == "win32":
     # From https://stackoverflow.com/a/38749458
     import os