Merge "adb: fix use after free of atransport." am: 1cdcc5f7e8 am: 15f0402787 am: f17e6bb38e

Change-Id: Ia213e96166a8f0a68e2f6fe0fda3adb0fa4da640
diff --git a/adb/daemon/auth.cpp b/adb/daemon/auth.cpp
index 2e84ce6..ec4ab4a 100644
--- a/adb/daemon/auth.cpp
+++ b/adb/daemon/auth.cpp
@@ -16,36 +16,84 @@
 
 #define TRACE_TAG AUTH
 
-#include "adb.h"
-#include "adb_auth.h"
-#include "adb_io.h"
-#include "fdevent/fdevent.h"
 #include "sysdeps.h"
-#include "transport.h"
 
 #include <resolv.h>
 #include <stdio.h>
 #include <string.h>
-#include <iomanip>
 
 #include <algorithm>
+#include <iomanip>
+#include <limits>
+#include <map>
 #include <memory>
 
 #include <adbd_auth.h>
 #include <android-base/file.h>
+#include <android-base/no_destructor.h>
 #include <android-base/strings.h>
 #include <crypto_utils/android_pubkey.h>
 #include <openssl/obj_mac.h>
 #include <openssl/rsa.h>
 #include <openssl/sha.h>
 
+#include "adb.h"
+#include "adb_auth.h"
+#include "adb_io.h"
+#include "fdevent/fdevent.h"
+#include "transport.h"
+#include "types.h"
+
 static AdbdAuthContext* auth_ctx;
 
 static void adb_disconnected(void* unused, atransport* t);
 static struct adisconnect adb_disconnect = {adb_disconnected, nullptr};
 
+// Transports with an outstanding authorization prompt, keyed by the id handed to
+// adbd_auth_prompt_user in place of the transport pointer. Values are weak references, so a
+// transport destroyed while its prompt is pending is observed as nullptr on lookup.
+static android::base::NoDestructor<std::map<uint32_t, weak_ptr<atransport>>> transports;
+// Next id to hand out (unsigned, so it wraps around after 2^32 prompts).
+static uint32_t transport_auth_id = 0;
+
 bool auth_required = true;
 
+// Registers |transport| under a fresh id and returns that id encoded as the opaque void*
+// callback argument. Must run on the main thread (creating a weak_ptr checks this).
+static void* transport_to_callback_arg(atransport* transport) {
+    uint32_t id = transport_auth_id++;
+    (*transports)[id] = transport->weak();
+    return reinterpret_cast<void*>(id);
+}
+
+// Maps a callback argument produced by transport_to_callback_arg back to its transport.
+// Returns nullptr if the id is unknown or the transport has been destroyed since. Each id is
+// consumed by the lookup that finds it, whether or not the transport is still alive.
+static atransport* transport_from_callback_arg(void* id) {
+    uint64_t id_u64 = reinterpret_cast<uint64_t>(id);
+    if (id_u64 > std::numeric_limits<uint32_t>::max()) {
+        LOG(FATAL) << "transport_from_callback_arg called on out of range value: " << id_u64;
+    }
+
+    uint32_t id_u32 = static_cast<uint32_t>(id_u64);
+    auto it = transports->find(id_u32);
+    if (it == transports->end()) {
+        LOG(ERROR) << "transport_from_callback_arg failed to find transport for id " << id_u32;
+        return nullptr;
+    }
+
+    atransport* t = it->second.get();
+    // Erase before the liveness check so that an entry whose transport is already gone does
+    // not stay in the map for the life of the process.
+    transports->erase(it);
+    if (!t) {
+        LOG(WARNING) << "transport_from_callback_arg found already destructed transport";
+        return nullptr;
+    }
+
+    return t;
+}
+
 static void IteratePublicKeys(std::function<bool(std::string_view public_key)> f) {
     adbd_auth_get_public_keys(
             auth_ctx,
@@ -111,9 +147,20 @@
 
+// Callback for a key the user has authorized (presumably registered with the adbd_auth
+// library in adbd_auth_init, which is not shown here). |arg| is expected to be an id from
+// transport_to_callback_arg rather than a transport pointer.
 static void adbd_auth_key_authorized(void* arg, uint64_t id) {
     LOG(INFO) << "adb client authorized";
-    auto* transport = static_cast<atransport*>(arg);
-    transport->auth_id = id;
-    adbd_auth_verified(transport);
+    // Resolve |arg| on the main thread, where weak_ptr requires to be used; the transport may
+    // have been destroyed by then, in which case the lookup yields nullptr.
+    fdevent_run_on_main_thread([=]() {
+        auto* transport = transport_from_callback_arg(arg);
+        if (!transport) {
+            LOG(ERROR) << "authorization received for deleted transport, ignoring";
+            return;
+        }
+        transport->auth_id = id;
+        adbd_auth_verified(transport);
+    });
 }
 
 void adbd_auth_init(void) {
@@ -158,7 +201,10 @@
 void adbd_auth_confirm_key(atransport* t) {
     LOG(INFO) << "prompting user to authorize key";
     t->AddDisconnect(&adb_disconnect);
-    adbd_auth_prompt_user(auth_ctx, t->auth_key.data(), t->auth_key.size(), t);
+    // Hand out an opaque id instead of |t| itself, so a late callback can detect that the
+    // transport was destroyed while the prompt was outstanding.
+    void* callback_arg = transport_to_callback_arg(t);
+    adbd_auth_prompt_user(auth_ctx, t->auth_key.data(), t->auth_key.size(), callback_arg);
 }
 
 void adbd_notify_framework_connected_key(atransport* t) {
diff --git a/adb/transport.h b/adb/transport.h
index 569e8bb..5a750ee 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -38,6 +38,7 @@
 
 #include "adb.h"
 #include "adb_unique_fd.h"
+#include "types.h"
 #include "usb.h"
 
 typedef std::unordered_set<std::string> FeatureSet;
@@ -223,7 +224,7 @@
     Abort,
 };
 
-class atransport {
+class atransport : public enable_weak_from_this<atransport> {
   public:
     // TODO(danalbert): We expose waaaaaaay too much stuff because this was
     // historically just a struct, but making the whole thing a more idiomatic
@@ -246,7 +247,7 @@
     }
     atransport(ConnectionState state = kCsOffline)
         : atransport([](atransport*) { return ReconnectResult::Abort; }, state) {}
-    virtual ~atransport();
+    ~atransport();
 
     int Write(apacket* p);
     void Reset();
diff --git a/adb/types.h b/adb/types.h
index 6b00224..c619fff 100644
--- a/adb/types.h
+++ b/adb/types.h
@@ -25,6 +25,7 @@
 
 #include <android-base/logging.h>
 
+#include "fdevent/fdevent.h"
 #include "sysdeps/uio.h"
 
 // Essentially std::vector<char>, except without zero initialization or reallocation.
@@ -245,3 +246,117 @@
     size_t start_index_ = 0;
     std::vector<block_type> chain_;
 };
+
+// An implementation of weak pointers tied to the fdevent run loop.
+//
+// This allows for code to submit a request for an object, and upon receiving
+// a response, know whether the object is still alive, or has been destroyed
+// because of other reasons. We keep a list of living weak_ptrs in each object,
+// and clear the weak_ptrs when the object is destroyed. This is safe, because
+// we require that both the destructor of the referent and the get method on
+// the weak_ptr are executed on the main thread.
+template <typename T>
+struct enable_weak_from_this;
+
+// Non-owning pointer to a T deriving from enable_weak_from_this<T>; it reads as nullptr
+// once the referent has been destroyed. All operations must run on the main thread.
+template <typename T>
+struct weak_ptr {
+    weak_ptr() = default;
+    explicit weak_ptr(T* ptr) { reset(ptr); }
+    weak_ptr(const weak_ptr& copy) { reset(copy.get()); }
+
+    weak_ptr(weak_ptr&& move) {
+        reset(move.get());
+        move.reset();
+    }
+
+    ~weak_ptr() { reset(); }
+
+    weak_ptr& operator=(const weak_ptr& copy) {
+        if (&copy == this) {
+            return *this;
+        }
+
+        reset(copy.get());
+        return *this;
+    }
+
+    weak_ptr& operator=(weak_ptr&& move) {
+        if (&move == this) {
+            return *this;
+        }
+
+        reset(move.get());
+        move.reset();
+        return *this;
+    }
+
+    // Returns the referent, or nullptr if it was destroyed or never set. Must be const:
+    // the copy operations above call it through a const reference.
+    T* get() const {
+        check_main_thread();
+        return ptr_;
+    }
+
+    // Stops observing the current referent (if any) and starts observing |ptr|.
+    void reset(T* ptr = nullptr) {
+        check_main_thread();
+
+        if (ptr == ptr_) {
+            return;
+        }
+
+        if (ptr_) {
+            // Erase the whole removed range; erase(end()) on a miss would be undefined.
+            auto& registry = ptr_->weak_ptrs_;
+            registry.erase(std::remove(registry.begin(), registry.end(), this), registry.end());
+        }
+
+        ptr_ = ptr;
+        if (ptr_) {
+            ptr_->weak_ptrs_.push_back(this);
+        }
+    }
+
+  private:
+    friend struct enable_weak_from_this<T>;
+    T* ptr_ = nullptr;
+};
+
+// Base class that makes a T observable through weak_ptr<T>. Destroying the object nulls
+// out every weak_ptr still referring to it.
+template <typename T>
+struct enable_weak_from_this {
+    enable_weak_from_this() = default;
+
+    // Weak references are tied to this object's address, so copies start with an empty
+    // registry and assignment leaves the target's registry alone (as with
+    // std::enable_shared_from_this). Copying the list would let the copy's destructor
+    // null out the original's weak_ptrs.
+    enable_weak_from_this(const enable_weak_from_this&) {}
+    enable_weak_from_this& operator=(const enable_weak_from_this&) { return *this; }
+
+    ~enable_weak_from_this() {
+        if (!weak_ptrs_.empty()) {
+            check_main_thread();
+            for (auto& weak : weak_ptrs_) {
+                weak->ptr_ = nullptr;
+            }
+            weak_ptrs_.clear();
+        }
+    }
+
+    weak_ptr<T> weak() { return weak_ptr<T>(static_cast<T*>(this)); }
+
+    // Deletes the object on the main thread. The delete must go through T*: the
+    // destructor here is not virtual, so deleting via this base pointer would be
+    // undefined behavior and would skip ~T().
+    void schedule_deletion() {
+        fdevent_run_on_main_thread([this]() { delete static_cast<T*>(this); });
+    }
+
+  private:
+    friend struct weak_ptr<T>;
+    std::vector<weak_ptr<T>*> weak_ptrs_;
+};