adb: don't close sockets before hitting EOF.
Reimplement commit ffc11d3cf3643c24780f902386be61851531f022 using
fdevent. The previous attempt was reverted because we were blindly
continuing when (revents & POLLIN) == 0, which ignored POLLHUP/POLLERR,
leading to spin loops when the opposite end of the file descriptor was
shut down while we had no data left to read.
This patch reimplements the functionality implemented by that commit
using fdevent, which gets us detection of spin loops for free.
Bug: http://b/74616284
Test: ./test_device.py
Change-Id: I1abd671fef4c29e99dad968aa66bb754ca382578
diff --git a/adb/socket_test.cpp b/adb/socket_test.cpp
index 7908f82..5e28f76 100644
--- a/adb/socket_test.cpp
+++ b/adb/socket_test.cpp
@@ -221,6 +221,8 @@
EXPECT_EQ(2u + GetAdditionalLocalSocketCount(), fdevent_installed_count());
ASSERT_EQ(0, adb_close(socket_fd[0]));
+ std::this_thread::sleep_for(2s);
+
WaitForFdeventLoop();
ASSERT_EQ(GetAdditionalLocalSocketCount(), fdevent_installed_count());
TerminateThread();
diff --git a/adb/sockets.cpp b/adb/sockets.cpp
index f7c39f0..420a6d5 100644
--- a/adb/sockets.cpp
+++ b/adb/sockets.cpp
@@ -26,6 +26,7 @@
#include <unistd.h>
#include <algorithm>
+#include <chrono>
#include <mutex>
#include <string>
#include <vector>
@@ -41,6 +42,8 @@
#include "transport.h"
#include "types.h"
+using namespace std::chrono_literals;
+
static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
static unsigned local_socket_next_id = 1;
@@ -238,16 +241,64 @@
fdevent_add(s->fde, FDE_READ);
}
+// State attached to a socket whose close has been deferred by deferred_close().
+// `begin` is the time the deferred close started; the flush callback compares
+// against it to cap the total time spent draining the socket.
+struct ClosingSocket {
+ std::chrono::steady_clock::time_point begin;
+};
+
+// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
+// socket while we have pending data, a TCP RST should be sent to the
+// other end to notify it that we didn't read all of its data. However,
+// this can cause data that we've successfully written out to be dropped
+// on the other end. To avoid this, instead of immediately closing a
+// socket, call shutdown on it, and then read from the file
+// descriptor until we hit EOF or an error before closing.
+//
+// Takes ownership of fd: the descriptor is released into a new fdevent,
+// which the callback below destroys once draining ends (EOF, a read error,
+// or the one-second budget running out).
+// NOTE(review): the drain loop only terminates early on EAGAIN, so this
+// presumably relies on fd being non-blocking -- confirm with callers.
+static void deferred_close(unique_fd fd) {
+ // Shutdown the socket in the outgoing direction only, so that
+ // we don't have the same problem on the opposite end.
+ adb_shutdown(fd.get(), SHUT_WR);
+ auto callback = [](fdevent* fde, unsigned event, void* arg) {
+ auto socket_info = static_cast<ClosingSocket*>(arg);
+ if (event & FDE_READ) {
+ ssize_t rc;
+ char buf[BUFSIZ];
+ // Discard incoming data; only the final return code is of interest.
+ while ((rc = adb_read(fde->fd.get(), buf, sizeof(buf))) > 0) {
+ continue;
+ }
+
+ if (rc == -1 && errno == EAGAIN) {
+ // There's potentially more data to read.
+ auto duration = std::chrono::steady_clock::now() - socket_info->begin;
+ if (duration > 1s) {
+ LOG(WARNING) << "timeout expired while flushing socket, closing";
+ } else {
+ // Still within the time budget: keep the fdevent and wait for more events.
+ return;
+ }
+ }
+ } else if (event & FDE_TIMEOUT) {
+ LOG(WARNING) << "timeout expired while flushing socket, closing";
+ }
+
+ // Either there was an error, we hit the end of the socket, or our timeout expired.
+ // presumably fdevent_destroy also closes the descriptor -- confirm in fdevent.
+ fdevent_destroy(fde);
+ delete socket_info;
+ };
+
+ // Heap-allocated because it must outlive this call; the callback deletes it.
+ ClosingSocket* socket_info = new ClosingSocket{
+ .begin = std::chrono::steady_clock::now(),
+ };
+
+ fdevent* fde = fdevent_create(fd.release(), callback, socket_info);
+ fdevent_add(fde, FDE_READ);
+ // presumably makes the callback fire with FDE_TIMEOUT after 1s without events -- confirm in fdevent.
+ fdevent_set_timeout(fde, 1s);
+}
+
// be sure to hold the socket list lock when calling this
static void local_socket_destroy(asocket* s) {
int exit_on_close = s->exit_on_close;
D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
- /* IMPORTANT: the remove closes the fd
- ** that belongs to this socket
- */
- fdevent_destroy(s->fde);
+ deferred_close(fdevent_release(s->fde));
remove_socket(s);
delete s;
diff --git a/adb/test_device.py b/adb/test_device.py
index 34f8fd9..f95a5b3 100755
--- a/adb/test_device.py
+++ b/adb/test_device.py
@@ -35,6 +35,8 @@
import time
import unittest
+from datetime import datetime
+
import adb
def requires_root(func):
@@ -1335,6 +1337,63 @@
self.device.forward_remove("tcp:{}".format(local_port))
+class SocketTest(DeviceTest):
+    def test_socket_flush(self):
+        """Test that we handle socket closure properly.
+
+        If we're done writing to a socket, closing before the other end has
+        closed will send a TCP RST if we have incoming data queued up, which
+        may result in data that we've written being discarded.
+
+        Bug: http://b/74616284
+        """
+        s = socket.create_connection(("localhost", 5037))
+
+        def adb_length_prefixed(string):
+            # Frame the request as a 4-digit hex length followed by the payload.
+            encoded = string.encode("utf8")
+            result = b"%04x%s" % (len(encoded), encoded)
+            return result
+
+        if "ANDROID_SERIAL" in os.environ:
+            transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
+        else:
+            transport_string = "host:transport-any"
+
+        s.sendall(adb_length_prefixed(transport_string))
+        response = s.recv(4)
+        self.assertEqual(b"OKAY", response)
+
+        shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
+        s.sendall(adb_length_prefixed(shell_string))
+
+        response = s.recv(4)
+        self.assertEqual(b"OKAY", response)
+
+        # Spawn a thread that dumps garbage into the socket until failure.
+        def spam():
+            buf = b"\0" * 16384
+            try:
+                while True:
+                    s.sendall(buf)
+            except Exception as ex:
+                print(ex)
+
+        thread = threading.Thread(target=spam)
+        thread.start()
+
+        time.sleep(1)
+
+        received = b""
+        while True:
+            read = s.recv(512)
+            if len(read) == 0:
+                break
+            received += read
+
+        # Nothing may be lost: 1 MiB of zeroes from dd plus the output of echo.
+        self.assertEqual(1024 * 1024 + len("foo\n"), len(received))
+        thread.join()
+
+
if sys.platform == "win32":
# From https://stackoverflow.com/a/38749458
import os