adb: tell the client what transport it received.

Prerequisite for making `adb root` wait for the device that it told to
restart to disappear: the client needs to know which transport to wait
on.

Bug: http://b/124244488
Test: manual
Change-Id: I474559838ad7c0e961e9d2a98c902bca3b60d6c8
diff --git a/adb/adb.cpp b/adb/adb.cpp
index 32fbb65..3c07882 100644
--- a/adb/adb.cpp
+++ b/adb/adb.cpp
@@ -1018,8 +1018,9 @@
     return 0;
 }
 
-bool handle_host_request(std::string_view service, TransportType type, const char* serial,
-                        TransportId transport_id, int reply_fd, asocket* s) {
+HostRequestResult handle_host_request(std::string_view service, TransportType type,
+                                      const char* serial, TransportId transport_id, int reply_fd,
+                                      asocket* s) {
     if (service == "kill") {
         fprintf(stderr, "adb server killed by remote request\n");
         fflush(stdout);
@@ -1032,29 +1033,49 @@
         exit(0);
     }
 
-    // "transport:" is used for switching transport with a specified serial number
-    // "transport-usb:" is used for switching transport to the only USB transport
-    // "transport-local:" is used for switching transport to the only local transport
-    // "transport-any:" is used for switching transport to the only transport
-    if (service.starts_with("transport")) {
+    LOG(DEBUG) << "handle_host_request(" << service << ")";
+
+    // Transport selection:
+    if (service.starts_with("transport") || service.starts_with("tport:")) {
         TransportType type = kTransportAny;
 
         std::string serial_storage;
+        bool legacy = true;
 
-        if (ConsumePrefix(&service, "transport-id:")) {
-            if (!ParseUint(&transport_id, service)) {
-                SendFail(reply_fd, "invalid transport id");
-                return true;
+        // New transport selection protocol:
+        // This is essentially identical to the previous version, except it returns the selected
+        // transport id to the caller as well.
+        if (ConsumePrefix(&service, "tport:")) {
+            legacy = false;
+            if (ConsumePrefix(&service, "serial:")) {
+                serial_storage = service;
+                serial = serial_storage.c_str();
+            } else if (service == "usb") {
+                type = kTransportUsb;
+            } else if (service == "local") {
+                type = kTransportLocal;
+            } else if (service == "any") {
+                type = kTransportAny;
             }
-        } else if (service == "transport-usb") {
-            type = kTransportUsb;
-        } else if (service == "transport-local") {
-            type = kTransportLocal;
-        } else if (service == "transport-any") {
-            type = kTransportAny;
-        } else if (ConsumePrefix(&service, "transport:")) {
-            serial_storage = service;
-            serial = serial_storage.c_str();
+
+            // Selection by id is unimplemented, since you obviously already know the transport id
+            // you're connecting to.
+        } else {
+            if (ConsumePrefix(&service, "transport-id:")) {
+                if (!ParseUint(&transport_id, service)) {
+                    SendFail(reply_fd, "invalid transport id");
+                    return HostRequestResult::Handled;
+                }
+            } else if (service == "transport-usb") {
+                type = kTransportUsb;
+            } else if (service == "transport-local") {
+                type = kTransportLocal;
+            } else if (service == "transport-any") {
+                type = kTransportAny;
+            } else if (ConsumePrefix(&service, "transport:")) {
+                serial_storage = service;
+                serial = serial_storage.c_str();
+            }
         }
 
         std::string error;
@@ -1063,11 +1084,15 @@
             s->transport = t;
             SendOkay(reply_fd);
 
-            // We succesfully handled the device selection, but there's another request coming.
-            return false;
+            if (!legacy) {
+                // Nothing we can do if this fails.
+                WriteFdExactly(reply_fd, &t->id, sizeof(t->id));
+            }
+
+            return HostRequestResult::SwitchedTransport;
         } else {
             SendFail(reply_fd, error);
-            return true;
+            return HostRequestResult::Handled;
         }
     }
 
@@ -1078,7 +1103,7 @@
         std::string device_list = list_transports(long_listing);
         D("Sending device list...");
         SendOkay(reply_fd, device_list);
-        return true;
+        return HostRequestResult::Handled;
     }
 
     if (service == "reconnect-offline") {
@@ -1094,7 +1119,7 @@
             response.resize(response.size() - 1);
         }
         SendOkay(reply_fd, response);
-        return true;
+        return HostRequestResult::Handled;
     }
 
     if (service == "features") {
@@ -1105,7 +1130,7 @@
         } else {
             SendFail(reply_fd, error);
         }
-        return true;
+        return HostRequestResult::Handled;
     }
 
     if (service == "host-features") {
@@ -1116,7 +1141,7 @@
         }
         features.insert(kFeaturePushSync);
         SendOkay(reply_fd, FeatureSetToString(features));
-        return true;
+        return HostRequestResult::Handled;
     }
 
     // remove TCP transport
@@ -1125,7 +1150,7 @@
         if (address.empty()) {
             kick_all_tcp_devices();
             SendOkay(reply_fd, "disconnected everything");
-            return true;
+            return HostRequestResult::Handled;
         }
 
         std::string serial;
@@ -1137,22 +1162,22 @@
         } else if (!android::base::ParseNetAddress(address, &host, &port, &serial, &error)) {
             SendFail(reply_fd, android::base::StringPrintf("couldn't parse '%s': %s",
                                                            address.c_str(), error.c_str()));
-            return true;
+            return HostRequestResult::Handled;
         }
         atransport* t = find_transport(serial.c_str());
         if (t == nullptr) {
             SendFail(reply_fd, android::base::StringPrintf("no such device '%s'", serial.c_str()));
-            return true;
+            return HostRequestResult::Handled;
         }
         kick_transport(t);
         SendOkay(reply_fd, android::base::StringPrintf("disconnected %s", address.c_str()));
-        return true;
+        return HostRequestResult::Handled;
     }
 
     // Returns our value for ADB_SERVER_VERSION.
     if (service == "version") {
         SendOkay(reply_fd, android::base::StringPrintf("%04x", ADB_SERVER_VERSION));
-        return true;
+        return HostRequestResult::Handled;
     }
 
     // These always report "unknown" rather than the actual error, for scripts.
@@ -1164,7 +1189,7 @@
         } else {
             SendFail(reply_fd, error);
         }
-        return true;
+        return HostRequestResult::Handled;
     }
     if (service == "get-devpath") {
         std::string error;
@@ -1174,7 +1199,7 @@
         } else {
             SendFail(reply_fd, error);
         }
-        return true;
+        return HostRequestResult::Handled;
     }
     if (service == "get-state") {
         std::string error;
@@ -1184,7 +1209,7 @@
         } else {
             SendFail(reply_fd, error);
         }
-        return true;
+        return HostRequestResult::Handled;
     }
 
     // Indicates a new emulator instance has started.
@@ -1197,7 +1222,7 @@
         }
 
         /* we don't even need to send a reply */
-        return true;
+        return HostRequestResult::Handled;
     }
 
     if (service == "reconnect") {
@@ -1209,7 +1234,7 @@
                     "reconnecting " + t->serial_name() + " [" + t->connection_state_name() + "]\n";
         }
         SendOkay(reply_fd, response);
-        return true;
+        return HostRequestResult::Handled;
     }
 
     // TODO: Switch handle_forward_request to string_view.
@@ -1220,10 +1245,10 @@
                     return acquire_one_transport(type, serial, transport_id, nullptr, error);
                 },
                 reply_fd)) {
-        return true;
+        return HostRequestResult::Handled;
     }
 
-    return false;
+    return HostRequestResult::Unhandled;
 }
 
 static auto& init_mutex = *new std::mutex();
diff --git a/adb/adb.h b/adb/adb.h
index f575adb..5eea8be 100644
--- a/adb/adb.h
+++ b/adb/adb.h
@@ -219,8 +219,15 @@
 #define USB_FFS_ADB_IN USB_FFS_ADB_EP(ep2)
 #endif
 
-bool handle_host_request(std::string_view service, TransportType type, const char* serial,
-                         TransportId transport_id, int reply_fd, asocket* s);
+enum class HostRequestResult {
+    Handled,
+    SwitchedTransport,
+    Unhandled,
+};
+
+HostRequestResult handle_host_request(std::string_view service, TransportType type,
+                                      const char* serial, TransportId transport_id, int reply_fd,
+                                      asocket* s);
 
 void handle_online(atransport* t);
 void handle_offline(atransport* t);
diff --git a/adb/client/adb_client.cpp b/adb/client/adb_client.cpp
index 0a09d1e..4cf3a74 100644
--- a/adb/client/adb_client.cpp
+++ b/adb/client/adb_client.cpp
@@ -70,46 +70,60 @@
     __adb_server_socket_spec = socket_spec;
 }
 
-static int switch_socket_transport(int fd, std::string* error) {
+static std::optional<TransportId> switch_socket_transport(int fd, std::string* error) {
+    TransportId result;
+    bool read_transport = true;
+
     std::string service;
     if (__adb_transport_id) {
+        read_transport = false;
         service += "host:transport-id:";
         service += std::to_string(__adb_transport_id);
+        result = __adb_transport_id;
     } else if (__adb_serial) {
-        service += "host:transport:";
+        service += "host:tport:serial:";
         service += __adb_serial;
     } else {
         const char* transport_type = "???";
         switch (__adb_transport) {
           case kTransportUsb:
-            transport_type = "transport-usb";
-            break;
+              transport_type = "usb";
+              break;
           case kTransportLocal:
-            transport_type = "transport-local";
-            break;
+              transport_type = "local";
+              break;
           case kTransportAny:
-            transport_type = "transport-any";
-            break;
+              transport_type = "any";
+              break;
           case kTransportHost:
             // no switch necessary
             return 0;
         }
-        service += "host:";
+        service += "host:tport:";
         service += transport_type;
     }
 
     if (!SendProtocolString(fd, service)) {
         *error = perror_str("write failure during connection");
-        return -1;
+        return std::nullopt;
     }
-    D("Switch transport in progress");
+
+    LOG(DEBUG) << "Switch transport in progress: " << service;
 
     if (!adb_status(fd, error)) {
         D("Switch transport failed: %s", error->c_str());
-        return -1;
+        return std::nullopt;
     }
+
+    if (read_transport) {
+        if (!ReadFdExactly(fd, &result, sizeof(result))) {
+            *error = "failed to read transport id from server";
+            return std::nullopt;
+        }
+    }
+
     D("Switch transport success");
-    return 0;
+    return result;
 }
 
 bool adb_status(int fd, std::string* error) {
@@ -133,11 +147,10 @@
     return false;
 }
 
-static int _adb_connect(const std::string& service, std::string* error) {
-    D("_adb_connect: %s", service.c_str());
+static int _adb_connect(std::string_view service, TransportId* transport, std::string* error) {
+    LOG(DEBUG) << "_adb_connect: " << service;
     if (service.empty() || service.size() > MAX_PAYLOAD) {
-        *error = android::base::StringPrintf("bad service name length (%zd)",
-                                             service.size());
+        *error = android::base::StringPrintf("bad service name length (%zd)", service.size());
         return -1;
     }
 
@@ -149,8 +162,15 @@
         return -2;
     }
 
-    if (memcmp(&service[0], "host", 4) != 0 && switch_socket_transport(fd.get(), error)) {
-        return -1;
+    if (!service.starts_with("host")) {
+        std::optional<TransportId> transport_result = switch_socket_transport(fd.get(), error);
+        if (!transport_result) {
+            return -1;
+        }
+
+        if (transport) {
+            *transport = *transport_result;
+        }
     }
 
     if (!SendProtocolString(fd.get(), service)) {
@@ -190,11 +210,15 @@
     return true;
 }
 
-int adb_connect(const std::string& service, std::string* error) {
-    // first query the adb server's version
-    unique_fd fd(_adb_connect("host:version", error));
+int adb_connect(std::string_view service, std::string* error) {
+    return adb_connect(nullptr, service, error);
+}
 
-    D("adb_connect: service %s", service.c_str());
+int adb_connect(TransportId* transport, std::string_view service, std::string* error) {
+    // first query the adb server's version
+    unique_fd fd(_adb_connect("host:version", nullptr, error));
+
+    LOG(DEBUG) << "adb_connect: service: " << service;
     if (fd == -2 && !is_local_socket_spec(__adb_server_socket_spec)) {
         fprintf(stderr, "* cannot start server on remote host\n");
         // error is the original network connection error
@@ -216,7 +240,7 @@
         // Fall through to _adb_connect.
     } else {
         // If a server is already running, check its version matches.
-        int version = ADB_SERVER_VERSION - 1;
+        int version = 0;
 
         // If we have a file descriptor, then parse version result.
         if (fd >= 0) {
@@ -254,7 +278,7 @@
         return 0;
     }
 
-    fd.reset(_adb_connect(service, error));
+    fd.reset(_adb_connect(service, transport, error));
     if (fd == -1) {
         D("_adb_connect error: %s", error->c_str());
     } else if(fd == -2) {
@@ -265,7 +289,6 @@
     return fd.release();
 }
 
-
 bool adb_command(const std::string& service) {
     std::string error;
     unique_fd fd(adb_connect(service, &error));
diff --git a/adb/client/adb_client.h b/adb/client/adb_client.h
index d467539..0a73787 100644
--- a/adb/client/adb_client.h
+++ b/adb/client/adb_client.h
@@ -24,7 +24,10 @@
 
 // Connect to adb, connect to the named service, and return a valid fd for
 // interacting with that service upon success or a negative number on failure.
-int adb_connect(const std::string& service, std::string* _Nonnull error);
+int adb_connect(std::string_view service, std::string* _Nonnull error);
+
+// Same as above, except returning the TransportId for the service that we've connected to.
+int adb_connect(TransportId* _Nullable id, std::string_view service, std::string* _Nonnull error);
 
 // Kill the currently running adb server, if it exists.
 bool adb_kill_server();
diff --git a/adb/sockets.cpp b/adb/sockets.cpp
index 04d92db..dc44026 100644
--- a/adb/sockets.cpp
+++ b/adb/sockets.cpp
@@ -792,16 +792,22 @@
 
         // Some requests are handled immediately -- in that case the handle_host_request() routine
         // has sent the OKAY or FAIL message and all we have to do is clean up.
-        if (handle_host_request(service, type,
-                                serial.empty() ? nullptr : std::string(serial).c_str(),
-                                transport_id, s->peer->fd, s)) {
-            LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'";
-            goto fail;
-        }
-        if (service.starts_with("transport")) {
-            D("SS(%d): okay transport", s->id);
-            s->smart_socket_data.clear();
-            return 0;
+        auto host_request_result = handle_host_request(
+                service, type, serial.empty() ? nullptr : std::string(serial).c_str(), transport_id,
+                s->peer->fd, s);
+
+        switch (host_request_result) {
+            case HostRequestResult::Handled:
+                LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'";
+                goto fail;
+
+            case HostRequestResult::SwitchedTransport:
+                D("SS(%d): okay transport", s->id);
+                s->smart_socket_data.clear();
+                return 0;
+
+            case HostRequestResult::Unhandled:
+                break;
         }
 
         /* try to find a local service with this name.