adb: make disconnect stop reconnection immediately.

Make `adb disconnect` remove transports immediately, instead of on
their next reconnection cycle.

Test: adb connect unreachable:12345; adb devices; adb disconnect; adb devices
Change-Id: I35c8b57344e847575596d09216fc636be47dde64
diff --git a/adb/test_adb.py b/adb/test_adb.py
index d4c98e4..86c13d0 100755
--- a/adb/test_adb.py
+++ b/adb/test_adb.py
@@ -28,6 +28,7 @@
 import struct
 import subprocess
 import threading
+import time
 import unittest
 
 
@@ -90,7 +91,7 @@
     server_thread.start()
 
     try:
-        yield port
+        yield port, writesock
     finally:
         writesock.close()
         server_thread.join()
@@ -120,7 +121,7 @@
 def adb_server():
     """Context manager for an ADB server.
 
-    This creates an ADB server and returns the port it"s listening on.
+    This creates an ADB server and returns the port it's listening on.
     """
 
     port = 5038
@@ -342,7 +343,7 @@
         Bug: http://b/78991667
         """
         with adb_server() as server_port:
-            with fake_adbd() as port:
+            with fake_adbd() as (port, _):
                 serial = "emulator-{}".format(port - 1)
                 # Ensure that the emulator is not there.
                 try:
@@ -380,7 +381,7 @@
         """
         for protocol in (socket.AF_INET, socket.AF_INET6):
             try:
-                with fake_adbd(protocol=protocol) as port:
+                with fake_adbd(protocol=protocol) as (port, _):
                     serial = "localhost:{}".format(port)
                     with adb_connect(self, serial):
                         pass
@@ -391,7 +392,7 @@
     def test_already_connected(self):
         """Ensure that an already-connected device stays connected."""
 
-        with fake_adbd() as port:
+        with fake_adbd() as (port, _):
             serial = "localhost:{}".format(port)
             with adb_connect(self, serial):
                 # b/31250450: this always returns 0 but probably shouldn't.
@@ -403,7 +404,7 @@
     def test_reconnect(self):
         """Ensure that a disconnected device reconnects."""
 
-        with fake_adbd() as port:
+        with fake_adbd() as (port, _):
             serial = "localhost:{}".format(port)
             with adb_connect(self, serial):
                 output = subprocess.check_output(["adb", "-s", serial,
@@ -439,6 +440,46 @@
                         "error: device '{}' not found".format(serial).encode("utf8"))
 
 
+class DisconnectionTest(unittest.TestCase):
+    """Tests for adb disconnect."""
+
+    def test_disconnect(self):
+        """Ensure that `adb disconnect` takes effect immediately."""
+
+        def _devices(port):
+            output = subprocess.check_output(["adb", "-P", str(port), "devices"])
+            return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]]
+
+        with adb_server() as server_port:
+            with fake_adbd() as (port, sock):
+                device_name = "localhost:{}".format(port)
+                output = subprocess.check_output(["adb", "-P", str(server_port),
+                                                  "connect", device_name])
+                self.assertEqual(output.strip(),
+                                  "connected to {}".format(device_name).encode("utf8"))
+
+
+                self.assertEqual(_devices(server_port), [[device_name, "device"]])
+
+                # Send a deliberately malformed packet to make the device go offline.
+                packet = struct.pack("IIIIII", 0, 0, 0, 0, 0, 0)
+                sock.sendall(packet)
+
+                # Wait a bit.
+                time.sleep(0.1)
+
+                self.assertEqual(_devices(server_port), [[device_name, "offline"]])
+
+                # Disconnect the device.
+                output = subprocess.check_output(["adb", "-P", str(server_port),
+                                                  "disconnect", device_name])
+
+                # Wait a bit.
+                time.sleep(0.1)
+
+                self.assertEqual(_devices(server_port), [])
+
+
 def main():
     """Main entrypoint."""
     random.seed(0)
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 4cd19f9..332e0f8 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -97,6 +97,9 @@
     // Adds the atransport* to the queue of reconnect attempts.
     void TrackTransport(atransport* transport);
 
+    // Wake up the ReconnectHandler thread to have it check for kicked transports.
+    void CheckForKicked();
+
   private:
     // The main thread loop.
     void Run();
@@ -166,6 +169,10 @@
     reconnect_cv_.notify_one();
 }
 
+void ReconnectHandler::CheckForKicked() {
+    reconnect_cv_.notify_one();
+}
+
 void ReconnectHandler::Run() {
     while (true) {
         ReconnectAttempt attempt;
@@ -184,10 +191,25 @@
             }
 
             if (!running_) return;
+
+            // Scan the whole list for kicked transports, so that we immediately handle an explicit
+            // disconnect request.
+            bool kicked = false;
+            for (auto it = reconnect_queue_.begin(); it != reconnect_queue_.end();) {
+                if (it->transport->kicked()) {
+                    D("transport %s was kicked. giving up on it.", it->transport->serial.c_str());
+                    remove_transport(it->transport);
+                    it = reconnect_queue_.erase(it);
+                } else {
+                    ++it;
+                }
+                kicked = true;
+            }
+
             if (reconnect_queue_.empty()) continue;
 
-            // Go back to sleep in case |reconnect_cv_| woke up spuriously and we still
-            // have more time to wait for the current attempt.
+            // Go back to sleep if we either woke up spuriously, or we were woken up to remove
+            // a kicked transport, and the first transport isn't ready for reconnection yet.
             auto now = std::chrono::steady_clock::now();
             if (reconnect_queue_.begin()->reconnect_time > now) {
                 continue;
@@ -195,11 +217,6 @@
 
             attempt = *reconnect_queue_.begin();
             reconnect_queue_.erase(reconnect_queue_.begin());
-            if (attempt.transport->kicked()) {
-                D("transport %s was kicked. giving up on it.", attempt.transport->serial.c_str());
-                remove_transport(attempt.transport);
-                continue;
-            }
         }
         D("attempting to reconnect %s", attempt.transport->serial.c_str());
 
@@ -448,6 +465,10 @@
     if (std::find(transport_list.begin(), transport_list.end(), t) != transport_list.end()) {
         t->Kick();
     }
+
+#if ADB_HOST
+    reconnect_handler.CheckForKicked();
+#endif
 }
 
 static int transport_registration_send = -1;
@@ -1276,6 +1297,9 @@
             t->Kick();
         }
     }
+#if ADB_HOST
+    reconnect_handler.CheckForKicked();
+#endif
 }
 
 #endif