Merge "adb: don't close sockets before hitting EOF."
diff --git a/adb/Android.bp b/adb/Android.bp
index bccc71a..7f82ca6 100644
--- a/adb/Android.bp
+++ b/adb/Android.bp
@@ -24,6 +24,7 @@
         "-Wno-missing-field-initializers",
         "-Wvla",
     ],
+    cpp_std: "gnu++17",
     rtti: true,
 
     use_version_lib: true,
diff --git a/adb/sockets.cpp b/adb/sockets.cpp
index 1534792..dfd9a0a 100644
--- a/adb/sockets.cpp
+++ b/adb/sockets.cpp
@@ -26,10 +26,14 @@
 #include <unistd.h>
 
 #include <algorithm>
+#include <map>
 #include <mutex>
 #include <string>
+#include <thread>
 #include <vector>
 
+#include <android-base/thread_annotations.h>
+
 #if !ADB_HOST
 #include <android-base/properties.h>
 #include <log/log_properties.h>
@@ -37,9 +41,150 @@
 
 #include "adb.h"
 #include "adb_io.h"
+#include "adb_utils.h"
+#include "sysdeps/chrono.h"
 #include "transport.h"
 #include "types.h"
 
+// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
+// socket while we have pending data, a TCP RST should be sent to the
+// other end to notify it that we didn't read all of its data. However,
+// this can result in data that we've successfully written out being
+// dropped on the other end. To avoid this, instead of immediately closing
+// a socket, shut down its write side and keep reading from the descriptor
+// until we hit EOF, an error, or a deadline, and only then close it.
+struct LingeringSocketCloser {
+    LingeringSocketCloser() = default;
+    ~LingeringSocketCloser() = delete;
+
+    // Defer thread creation until it's needed, because we need for there to
+    // only be one thread when dropping privileges in adbd.
+    void Start() {
+        CHECK(!thread_.joinable());
+
+        int fds[2];
+        if (adb_socketpair(fds) != 0) {
+            PLOG(FATAL) << "adb_socketpair failed";
+        }
+
+        set_file_block_mode(fds[0], false);
+        set_file_block_mode(fds[1], false);
+
+        notify_fd_read_.reset(fds[0]);
+        notify_fd_write_.reset(fds[1]);
+
+        thread_ = std::thread([this]() { Run(); });
+    }
+
+    // Takes ownership of |socket|; it's closed on EOF/error or ~1s later.
+    void EnqueueSocket(unique_fd socket) {
+        // Shutdown the socket in the outgoing direction only, so that
+        // we don't have the same problem on the opposite end.
+        adb_shutdown(socket.get(), SHUT_WR);
+        set_file_block_mode(socket.get(), false);
+
+        std::lock_guard<std::mutex> lock(mutex_);
+        int fd = socket.get();
+        SocketInfo info = {
+                .fd = std::move(socket),
+                .deadline = std::chrono::steady_clock::now() + 1s,
+        };
+
+        D("LingeringSocketCloser received fd %d", fd);
+
+        fds_.emplace(fd, std::move(info));
+        if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) {
+            PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd";
+        }
+    }
+
+  private:
+    std::vector<adb_pollfd> GeneratePollFds() {
+        std::lock_guard<std::mutex> lock(mutex_);
+        std::vector<adb_pollfd> result;
+        result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN});
+        for (auto& [fd, _] : fds_) {
+            result.push_back(adb_pollfd{.fd = fd, .events = POLLIN});
+        }
+        return result;
+    }
+
+    // Poll the notify fd and every lingering socket, drain whatever is
+    // readable, and close sockets that hit EOF, an error, or their deadline.
+    void Run() {
+        while (true) {
+            std::vector<adb_pollfd> pfds = GeneratePollFds();
+            int rc = adb_poll(pfds.data(), pfds.size(), 1000);
+            if (rc == -1) {
+                PLOG(FATAL) << "poll failed in LingeringSocketCloser";
+            }
+
+            std::lock_guard<std::mutex> lock(mutex_);
+            for (auto& pfd : pfds) {
+                // POLLHUP/POLLERR can arrive without POLLIN; let read() report them.
+                if ((pfd.revents & (POLLIN | POLLHUP | POLLERR)) == 0) {
+                    continue;
+                }
+
+                // Drain the fd, but bound the work per wakeup so that a peer
+                // that writes continuously can't pin us here holding mutex_.
+                ssize_t read_rc = -1;
+                char buf[32768];
+                for (int i = 0; i < 16; ++i) {
+                    read_rc = adb_read(pfd.fd, buf, sizeof(buf));
+                    if (read_rc <= 0) break;
+                }
+
+                if (pfd.fd == notify_fd_read_) {
+                    continue;
+                }
+
+                auto it = fds_.find(pfd.fd);
+                if (it == fds_.end()) {
+                    LOG(FATAL) << "fd is missing";
+                }
+
+                if (read_rc > 0 || (read_rc == -1 && errno == EAGAIN)) {
+                    // Still open; the deadline sweep below decides when to close it.
+                    continue;
+                } else if (read_rc == -1) {
+                    D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno);
+                } else {
+                    D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd);
+                }
+
+                fds_.erase(it);
+            }
+
+            // Check deadlines on every iteration, not only on poll timeout, so
+            // that steady traffic on other fds can't keep idle sockets open.
+            auto now = std::chrono::steady_clock::now();
+            for (auto it = fds_.begin(); it != fds_.end();) {
+                if (now > it->second.deadline) {
+                    D("LingeringSocketCloser closing fd %d due to deadline", it->first);
+                    it = fds_.erase(it);
+                } else {
+                    ++it;
+                }
+            }
+        }
+    }
+
+    std::thread thread_;
+    unique_fd notify_fd_read_;
+    unique_fd notify_fd_write_;
+
+    struct SocketInfo {
+        unique_fd fd;
+        std::chrono::steady_clock::time_point deadline;
+    };
+
+    std::mutex mutex_;
+    std::map<int, SocketInfo> fds_ GUARDED_BY(mutex_);
+};
+
+static LingeringSocketCloser& socket_closer = *new LingeringSocketCloser();  // Never freed.
+
 static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
 static unsigned local_socket_next_id = 1;
 
@@ -243,10 +388,12 @@
 
     D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
 
-    /* IMPORTANT: the remove closes the fd
-    ** that belongs to this socket
-    */
-    fdevent_destroy(s->fde);
+    // Defer thread creation until it's needed, because we need for there to
+    // only be one thread when dropping privileges in adbd.
+    static std::once_flag once;
+    std::call_once(once, []() { socket_closer.Start(); });
+
+    socket_closer.EnqueueSocket(fdevent_release(s->fde));
 
     remove_socket(s);
     delete s;
diff --git a/adb/test_device.py b/adb/test_device.py
index c3166ff..4c45a73 100755
--- a/adb/test_device.py
+++ b/adb/test_device.py
@@ -35,6 +35,8 @@
 import time
 import unittest
 
+from datetime import datetime
+
 import adb
 
 def requires_root(func):
@@ -1335,6 +1337,63 @@
             self.device.forward_remove("tcp:{}".format(local_port))
 
 
+class SocketTest(DeviceTest):
+    def test_socket_flush(self):
+        """Test that we handle socket closure properly.
+
+        If we're done writing to a socket, closing before the other end has
+        closed will send a TCP_RST if we have incoming data queued up, which
+        may result in data that we've written being discarded.
+
+        Bug: http://b/74616284
+        """
+        s = socket.create_connection(("localhost", 5037))
+
+        def adb_length_prefixed(string):
+            encoded = string.encode("utf8")
+            result = b"%04x%s" % (len(encoded), encoded)
+            return result
+
+        if "ANDROID_SERIAL" in os.environ:
+            transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
+        else:
+            transport_string = "host:transport-any"
+
+        s.sendall(adb_length_prefixed(transport_string))
+        response = s.recv(4)
+        self.assertEquals(b"OKAY", response)
+
+        shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
+        s.sendall(adb_length_prefixed(shell_string))
+
+        response = s.recv(4)
+        self.assertEquals(b"OKAY", response)
+
+        # Spawn a thread that dumps garbage into the socket until failure.
+        def spam():
+            buf = b"\0" * 16384
+            try:
+                while True:
+                    s.sendall(buf)
+            except Exception as ex:
+                print(ex)
+
+        thread = threading.Thread(target=spam)
+        thread.start()
+
+        time.sleep(1)
+
+        received = b""
+        while True:
+            read = s.recv(512)
+            if len(read) == 0:
+                break
+            received += read
+
+        self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
+        thread.join()
+
+
 if sys.platform == "win32":
     # From https://stackoverflow.com/a/38749458
     import os