Merge "Some minor fixes to libadb_tls_connection." am: 8f04b0ca58 am: 7a15d5c88f

Change-Id: I32238ce9abd608ff290d903df5b75fdd81d49cae
diff --git a/adb/tls/adb_ca_list.cpp b/adb/tls/adb_ca_list.cpp
index 8d37bbe..36afe42 100644
--- a/adb/tls/adb_ca_list.cpp
+++ b/adb/tls/adb_ca_list.cpp
@@ -32,13 +32,13 @@
 // CA issuer identifier to distinguished embedded keys. Also has version
 // information appended to the end of the string (e.g. "AdbKey-0").
 static constexpr int kAdbKeyIdentifierNid = NID_organizationName;
-static constexpr char kAdbKeyIdentifierPrefix[] = "AdbKey-";
-static constexpr int kAdbKeyVersion = 0;
+static constexpr char kAdbKeyIdentifierV0[] = "AdbKey-0";
 
 // Where we store the actual data
 static constexpr int kAdbKeyValueNid = NID_commonName;
 
 // TODO: Remove this once X509_NAME_add_entry_by_NID is fixed to use const unsigned char*
+// https://boringssl-review.googlesource.com/c/boringssl/+/39764
 int X509_NAME_add_entry_by_NID_const(X509_NAME* name, int nid, int type, const unsigned char* bytes,
                                      int len, int loc, int set) {
     return X509_NAME_add_entry_by_NID(name, nid, type, const_cast<unsigned char*>(bytes), len, loc,
@@ -55,13 +55,13 @@
     // |len| is the len of the text excluding the final null
     int len = X509_NAME_get_text_by_NID(name, nid, nullptr, -1);
     if (len <= 0) {
-        return {};
+        return std::nullopt;
     }
 
     // Include the space for the final null byte
     std::vector<char> buf(len + 1, '\0');
     CHECK(X509_NAME_get_text_by_NID(name, nid, buf.data(), buf.size()));
-    return buf.data();
+    return std::make_optional(std::string(buf.data()));
 }
 
 }  // namespace
@@ -73,8 +73,7 @@
     // "O=AdbKey-0;CN=<key>;"
     CHECK(!key.empty());
 
-    std::string identifier = kAdbKeyIdentifierPrefix;
-    identifier += std::to_string(kAdbKeyVersion);
+    std::string identifier = kAdbKeyIdentifierV0;
     bssl::UniquePtr<X509_NAME> name(X509_NAME_new());
     CHECK(X509_NAME_add_entry_by_NID_const(name.get(), kAdbKeyIdentifierNid, MBSTRING_ASC,
                                            reinterpret_cast<const uint8_t*>(identifier.data()),
@@ -91,27 +90,34 @@
     CHECK(issuer);
 
     auto buf = GetX509NameTextByNid(issuer, kAdbKeyIdentifierNid);
-    if (!buf || !android::base::StartsWith(*buf, kAdbKeyIdentifierPrefix)) {
-        return {};
+    if (!buf) {
+        return std::nullopt;
     }
 
-    return GetX509NameTextByNid(issuer, kAdbKeyValueNid);
+    // Check for supported versions; only "AdbKey-0" is currently recognized.
+    if (*buf == kAdbKeyIdentifierV0) {
+        return GetX509NameTextByNid(issuer, kAdbKeyValueNid);
+    }
+    return std::nullopt;
 }
 
 std::string SHA256BitsToHexString(std::string_view sha256) {
     CHECK_EQ(sha256.size(), static_cast<size_t>(SHA256_DIGEST_LENGTH));
     std::stringstream ss;
+    auto* u8 = reinterpret_cast<const uint8_t*>(sha256.data());
     ss << std::uppercase << std::setfill('0') << std::hex;
     // Convert to hex-string representation
     for (size_t i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
-        ss << std::setw(2) << (0x00FF & sha256[i]);
+        // Need to cast to something bigger than one byte, or
+        // stringstream will interpret it as a char value.
+        ss << std::setw(2) << static_cast<uint16_t>(u8[i]);
     }
     return ss.str();
 }
 
 std::optional<std::string> SHA256HexStringToBits(std::string_view sha256_str) {
     if (sha256_str.size() != SHA256_DIGEST_LENGTH * 2) {
-        return {};
+        return std::nullopt;
     }
 
     std::string result;
@@ -119,7 +125,7 @@
         auto bytestr = std::string(sha256_str.substr(i * 2, 2));
         if (!IsHexDigit(bytestr[0]) || !IsHexDigit(bytestr[1])) {
             LOG(ERROR) << "SHA256 string has invalid non-hex chars";
-            return {};
+            return std::nullopt;
         }
         result += static_cast<char>(std::stol(bytestr, nullptr, 16));
     }
diff --git a/adb/tls/include/adb/tls/tls_connection.h b/adb/tls/include/adb/tls/tls_connection.h
index ae70857..bc5b98a 100644
--- a/adb/tls/include/adb/tls/tls_connection.h
+++ b/adb/tls/include/adb/tls/tls_connection.h
@@ -55,16 +55,15 @@
 
     // Adds a trusted certificate to the list for the SSL connection.
     // During the handshake phase, it will check the list of trusted certificates.
-    // The connection will fail if the peer's certificate is not in the list. Use
-    // |EnableCertificateVerification(false)| to disable certificate
-    // verification.
+    // The connection will fail if the peer's certificate is not in the list. To
+    // accept any certificate instead, use |SetCertVerifyCallback| with a
+    // callback that always returns 1.
     //
     // Returns true if |cert| was successfully added, false otherwise.
     virtual bool AddTrustedCertificate(std::string_view cert) = 0;
 
     // Sets a custom certificate verify callback. |cb| must return 1 if the
-    // certificate is trusted. Otherwise, return 0 if not. Note that |cb| is
-    // only used if EnableCertificateVerification(false).
+    // certificate is trusted, or 0 if it is not.
     virtual void SetCertVerifyCallback(CertVerifyCb cb) = 0;
 
     // Configures a client |ca_list| that the server sends to the client in the
diff --git a/adb/tls/tests/tls_connection_test.cpp b/adb/tls/tests/tls_connection_test.cpp
index 880904b..27bc1c9 100644
--- a/adb/tls/tests/tls_connection_test.cpp
+++ b/adb/tls/tests/tls_connection_test.cpp
@@ -199,24 +199,10 @@
 static std::vector<CAIssuer> kCAIssuers = {
         {
                 {NID_commonName, {'a', 'b', 'c', 'd', 'e'}},
-                {NID_organizationName,
-                 {
-                         'd',
-                         'e',
-                         'f',
-                         'g',
-                 }},
+                {NID_organizationName, {'d', 'e', 'f', 'g'}},
         },
         {
-                {NID_commonName,
-                 {
-                         'h',
-                         'i',
-                         'j',
-                         'k',
-                         'l',
-                         'm',
-                 }},
+                {NID_commonName, {'h', 'i', 'j', 'k', 'l', 'm'}},
                 {NID_countryName, {'n', 'o'}},
         },
 };
@@ -224,8 +210,6 @@
 class AdbWifiTlsConnectionTest : public testing::Test {
   protected:
     virtual void SetUp() override {
-        // TODO: move client code in each test into its own thread, as the
-        // socket pair buffer is limited.
         android::base::Socketpair(SOCK_STREAM, &server_fd_, &client_fd_);
         server_ = TlsConnection::Create(TlsConnection::Role::Server, kTestRsa2048ServerCert,
                                         kTestRsa2048ServerPrivKey, server_fd_);
@@ -257,14 +241,8 @@
         return ret;
     }
 
-    void StartClientHandshakeAsync(bool expect_success) {
-        client_thread_ = std::thread([=]() {
-            if (expect_success) {
-                EXPECT_EQ(client_->DoHandshake(), TlsError::Success);
-            } else {
-                EXPECT_NE(client_->DoHandshake(), TlsError::Success);
-            }
-        });
+    void StartClientHandshakeAsync(TlsError expected) {
+        client_thread_ = std::thread([=]() { EXPECT_EQ(client_->DoHandshake(), expected); });
     }
 
     void WaitForClientConnection() {
@@ -313,45 +291,52 @@
     // Allow any certificate
     server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
     client_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
-    StartClientHandshakeAsync(true);
+    StartClientHandshakeAsync(TlsError::Success);
 
     // Handshake should succeed
-    EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
     WaitForClientConnection();
 
-    // Client write, server read
-    EXPECT_TRUE(client_->WriteFully(
-            std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+    // Test client/server reads and writes
+    client_thread_ = std::thread([&]() {
+        EXPECT_TRUE(client_->WriteFully(
+                std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+        // Try with overloaded ReadFully
+        std::vector<uint8_t> buf(msg_.size());
+        ASSERT_TRUE(client_->ReadFully(buf.data(), msg_.size()));
+        EXPECT_EQ(buf, msg_);
+    });
+
     auto data = server_->ReadFully(msg_.size());
     EXPECT_EQ(data, msg_);
-
-    // Client read, server write
     EXPECT_TRUE(server_->WriteFully(
             std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
-    // Try with overloaded ReadFully
-    std::vector<uint8_t> buf(msg_.size());
-    ASSERT_TRUE(client_->ReadFully(buf.data(), msg_.size()));
-    EXPECT_EQ(buf, msg_);
+
+    WaitForClientConnection();
 }
 
 TEST_F(AdbWifiTlsConnectionTest, NoTrustedCertificates) {
-    StartClientHandshakeAsync(false);
+    StartClientHandshakeAsync(TlsError::CertificateRejected);
 
     // Handshake should not succeed
-    EXPECT_NE(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
     WaitForClientConnection();
 
-    // Client write, server read should fail
-    EXPECT_FALSE(client_->WriteFully(
-            std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+    // All writes and reads should fail
+    client_thread_ = std::thread([&]() {
+        // Client write, server read should fail
+        EXPECT_FALSE(client_->WriteFully(
+                std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+        auto data = client_->ReadFully(msg_.size());
+        EXPECT_EQ(data.size(), 0);
+    });
+
     auto data = server_->ReadFully(msg_.size());
     EXPECT_EQ(data.size(), 0);
-
-    // Client read, server write should fail
     EXPECT_FALSE(server_->WriteFully(
             std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
-    data = client_->ReadFully(msg_.size());
-    EXPECT_EQ(data.size(), 0);
+
+    WaitForClientConnection();
 }
 
 TEST_F(AdbWifiTlsConnectionTest, AddTrustedCertificates) {
@@ -359,23 +344,26 @@
     EXPECT_TRUE(client_->AddTrustedCertificate(kTestRsa2048ServerCert));
     EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048ClientCert));
 
-    StartClientHandshakeAsync(true);
+    StartClientHandshakeAsync(TlsError::Success);
 
     // Handshake should succeed
-    EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
     WaitForClientConnection();
 
-    // Client write, server read
-    EXPECT_TRUE(client_->WriteFully(
-            std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+    // All reads and writes should succeed
+    client_thread_ = std::thread([&]() {
+        EXPECT_TRUE(client_->WriteFully(
+                std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+        auto data = client_->ReadFully(msg_.size());
+        EXPECT_EQ(data, msg_);
+    });
+
     auto data = server_->ReadFully(msg_.size());
     EXPECT_EQ(data, msg_);
-
-    // Client read, server write
     EXPECT_TRUE(server_->WriteFully(
             std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
-    data = client_->ReadFully(msg_.size());
-    EXPECT_EQ(data, msg_);
+
+    WaitForClientConnection();
 }
 
 TEST_F(AdbWifiTlsConnectionTest, AddTrustedCertificates_ClientWrongCert) {
@@ -387,23 +375,26 @@
     // Without enabling EnableClientPostHandshakeCheck(), DoHandshake() will
     // succeed, because in TLS 1.3, the client doesn't get notified if the
     // server rejected the certificate until a read operation is called.
-    StartClientHandshakeAsync(true);
+    StartClientHandshakeAsync(TlsError::Success);
 
     // Handshake should fail for server, succeed for client
-    EXPECT_NE(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
     WaitForClientConnection();
 
-    // Client write succeeds, server read should fail
-    EXPECT_TRUE(client_->WriteFully(
-            std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+    // Client writes will succeed, everything else will fail.
+    client_thread_ = std::thread([&]() {
+        EXPECT_TRUE(client_->WriteFully(
+                std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+        auto data = client_->ReadFully(msg_.size());
+        EXPECT_EQ(data.size(), 0);
+    });
+
     auto data = server_->ReadFully(msg_.size());
     EXPECT_EQ(data.size(), 0);
-
-    // Client read, server write should fail
     EXPECT_FALSE(server_->WriteFully(
             std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
-    data = client_->ReadFully(msg_.size());
-    EXPECT_EQ(data.size(), 0);
+
+    WaitForClientConnection();
 }
 
 TEST_F(AdbWifiTlsConnectionTest, ExportKeyingMaterial) {
@@ -415,10 +406,10 @@
     EXPECT_TRUE(client_->AddTrustedCertificate(kTestRsa2048ServerCert));
     EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048ClientCert));
 
-    StartClientHandshakeAsync(true);
+    StartClientHandshakeAsync(TlsError::Success);
 
     // Handshake should succeed
-    EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
     WaitForClientConnection();
 
     // Verify the client and server's exported key material match.
@@ -439,10 +430,10 @@
     // Client handshake should succeed, because in TLS 1.3, client does not
     // realize that the peer rejected the certificate until after a read
     // operation.
-    client_thread_ = std::thread([&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::Success); });
+    StartClientHandshakeAsync(TlsError::Success);
 
     // Server handshake should fail
-    EXPECT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
     WaitForClientConnection();
 }
 
@@ -455,11 +446,10 @@
     server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 0; });
 
     // Client handshake should fail because server rejects everything
-    client_thread_ = std::thread(
-            [&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::PeerRejectedCertificate); });
+    StartClientHandshakeAsync(TlsError::PeerRejectedCertificate);
 
     // Server handshake should fail
-    EXPECT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
     WaitForClientConnection();
 }
 
@@ -469,11 +459,10 @@
     // Server accepts all
     server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
     // Client handshake should fail
-    client_thread_ = std::thread(
-            [&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::CertificateRejected); });
+    StartClientHandshakeAsync(TlsError::CertificateRejected);
 
     // Server handshake should fail
-    EXPECT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
     WaitForClientConnection();
 }
 
@@ -488,15 +477,15 @@
     server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
 
     // Client handshake should fail
-    client_thread_ = std::thread(
-            [&]() { EXPECT_EQ(client_->DoHandshake(), TlsError::CertificateRejected); });
+    StartClientHandshakeAsync(TlsError::CertificateRejected);
 
     // Server handshake should fail
-    EXPECT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::PeerRejectedCertificate);
     WaitForClientConnection();
 }
 
 TEST_F(AdbWifiTlsConnectionTest, EnableClientPostHandshakeCheck_ClientWrongCert) {
+    EXPECT_TRUE(client_->AddTrustedCertificate(kTestRsa2048ServerCert));
     // client's DoHandshake() will fail if the server rejected the certificate
     client_->EnableClientPostHandshakeCheck(true);
 
@@ -504,23 +493,26 @@
     EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048UnknownCert));
 
     // Handshake should fail for client
-    StartClientHandshakeAsync(false);
+    StartClientHandshakeAsync(TlsError::PeerRejectedCertificate);
 
     // Handshake should fail for server
-    EXPECT_NE(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::CertificateRejected);
     WaitForClientConnection();
 
-    // Client write fails, server read should fail
-    EXPECT_FALSE(client_->WriteFully(
-            std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+    // All reads and writes should fail
+    client_thread_ = std::thread([&]() {
+        EXPECT_FALSE(client_->WriteFully(
+                std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
+        auto data = client_->ReadFully(msg_.size());
+        EXPECT_EQ(data.size(), 0);
+    });
+
     auto data = server_->ReadFully(msg_.size());
     EXPECT_EQ(data.size(), 0);
-
-    // Client read, server write should fail
     EXPECT_FALSE(server_->WriteFully(
             std::string_view(reinterpret_cast<const char*>(msg_.data()), msg_.size())));
-    data = client_->ReadFully(msg_.size());
-    EXPECT_EQ(data.size(), 0);
+
+    WaitForClientConnection();
 }
 
 TEST_F(AdbWifiTlsConnectionTest, SetClientCAList_Empty) {
@@ -569,12 +561,12 @@
             return 1;
         });
         // Client handshake should succeed
-        EXPECT_EQ(client_->DoHandshake(), TlsError::Success);
+        ASSERT_EQ(client_->DoHandshake(), TlsError::Success);
     });
 
     EXPECT_TRUE(server_->AddTrustedCertificate(kTestRsa2048UnknownCert));
     // Server handshake should succeed
-    EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
     client_thread_.join();
 }
 
@@ -604,12 +596,12 @@
             return 1;
         });
         // Client handshake should succeed
-        EXPECT_EQ(client_->DoHandshake(), TlsError::Success);
+        ASSERT_EQ(client_->DoHandshake(), TlsError::Success);
     });
 
     server_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
     // Server handshake should succeed
-    EXPECT_EQ(server_->DoHandshake(), TlsError::Success);
+    ASSERT_EQ(server_->DoHandshake(), TlsError::Success);
     client_thread_.join();
 }
 }  // namespace tls
diff --git a/adb/tls/tls_connection.cpp b/adb/tls/tls_connection.cpp
index 7df6ef4..853cdac 100644
--- a/adb/tls/tls_connection.cpp
+++ b/adb/tls/tls_connection.cpp
@@ -61,6 +61,7 @@
     static const char* SSLErrorString();
     void Invalidate();
     TlsError GetFailureReason(int err);
+    const char* RoleToString() { return role_ == Role::Server ? kServerRoleStr : kClientRoleStr; }
 
     Role role_;
     bssl::UniquePtr<EVP_PKEY> priv_key_;
@@ -75,15 +76,19 @@
     CertVerifyCb cert_verify_cb_;
     SetCertCb set_cert_cb_;
     borrowed_fd fd_;
+    static constexpr char kClientRoleStr[] = "[client]: ";
+    static constexpr char kServerRoleStr[] = "[server]: ";
 };  // TlsConnectionImpl
 
 TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
                                      borrowed_fd fd)
     : role_(role), fd_(fd) {
     CHECK(!cert.empty() && !priv_key.empty());
-    LOG(INFO) << "Initializing adbwifi TlsConnection";
+    LOG(INFO) << RoleToString() << "Initializing adbwifi TlsConnection";
     cert_ = BufferFromPEM(cert);
+    CHECK(cert_);
     priv_key_ = EvpPkeyFromPEM(priv_key);
+    CHECK(priv_key_);
 }
 
 TlsConnectionImpl::~TlsConnectionImpl() {
@@ -149,7 +154,7 @@
     // Create X509 buffer from the certificate string
     auto buf = X509FromBuffer(BufferFromPEM(cert));
     if (buf == nullptr) {
-        LOG(ERROR) << "Failed to create a X509 buffer for the certificate.";
+        LOG(ERROR) << RoleToString() << "Failed to create a X509 buffer for the certificate.";
         return false;
     }
     known_certificates_.push_back(std::move(buf));
@@ -205,8 +210,7 @@
 }
 
 TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
-    int err = -1;
-    LOG(INFO) << "Starting adbwifi tls handshake";
+    LOG(INFO) << RoleToString() << "Starting adbwifi tls handshake";
     ssl_ctx_.reset(SSL_CTX_new(TLS_method()));
     // TODO: Remove set_max_proto_version() once external/boringssl is updated
     // past
@@ -214,14 +218,14 @@
     if (ssl_ctx_.get() == nullptr ||
         !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) ||
         !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) {
-        LOG(ERROR) << "Failed to create SSL context";
+        LOG(ERROR) << RoleToString() << "Failed to create SSL context";
         return TlsError::UnknownFailure;
     }
 
     // Register user-supplied known certificates
     for (auto const& cert : known_certificates_) {
         if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) {
-            LOG(ERROR) << "Unable to add certificates into the X509_STORE";
+            LOG(ERROR) << RoleToString() << "Unable to add certificates into the X509_STORE";
             return TlsError::UnknownFailure;
         }
     }
@@ -248,7 +252,8 @@
     };
     if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(),
                                    priv_key_.get(), nullptr)) {
-        LOG(ERROR) << "Unable to register the certificate chain file and private key ["
+        LOG(ERROR) << RoleToString()
+                   << "Unable to register the certificate chain file and private key ["
                    << SSLErrorString() << "]";
         Invalidate();
         return TlsError::UnknownFailure;
@@ -259,19 +264,21 @@
     // Okay! Let's try to do the handshake!
     ssl_.reset(SSL_new(ssl_ctx_.get()));
     if (!SSL_set_fd(ssl_.get(), fd_.get())) {
-        LOG(ERROR) << "SSL_set_fd failed. [" << SSLErrorString() << "]";
+        LOG(ERROR) << RoleToString() << "SSL_set_fd failed. [" << SSLErrorString() << "]";
         return TlsError::UnknownFailure;
     }
+
     switch (role_) {
         case Role::Server:
-            err = SSL_accept(ssl_.get());
+            SSL_set_accept_state(ssl_.get());
             break;
         case Role::Client:
-            err = SSL_connect(ssl_.get());
+            SSL_set_connect_state(ssl_.get());
             break;
     }
-    if (err != 1) {
-        LOG(ERROR) << "Handshake failed in SSL_accept/SSL_connect [" << SSLErrorString() << "]";
+    if (SSL_do_handshake(ssl_.get()) != 1) {
+        LOG(ERROR) << RoleToString() << "Handshake failed in SSL_do_handshake ["
+                   << SSLErrorString() << "]";
         auto sslerr = ERR_get_error();
         Invalidate();
         return GetFailureReason(sslerr);
@@ -281,16 +288,16 @@
         uint8_t check;
         // Try to peek one byte for any failures. This assumes on success that
         // the server actually sends something.
-        err = SSL_peek(ssl_.get(), &check, 1);
-        if (err <= 0) {
-            LOG(ERROR) << "Post-handshake SSL_peek failed [" << SSLErrorString() << "]";
+        if (SSL_peek(ssl_.get(), &check, 1) <= 0) {
+            LOG(ERROR) << RoleToString() << "Post-handshake SSL_peek failed [" << SSLErrorString()
+                       << "]";
             auto sslerr = ERR_get_error();
             Invalidate();
             return GetFailureReason(sslerr);
         }
     }
 
-    LOG(INFO) << "Handshake succeeded.";
+    LOG(INFO) << RoleToString() << "Handshake succeeded.";
     return TlsError::Success;
 }
 
@@ -311,7 +318,7 @@
 bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
     CHECK_GT(size, 0U);
     if (!ssl_) {
-        LOG(ERROR) << "Tried to read on a null SSL connection";
+        LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
         return false;
     }
 
@@ -321,7 +328,7 @@
         int bytes_read =
                 SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size));
         if (bytes_read <= 0) {
-            LOG(WARNING) << "SSL_read failed [" << SSLErrorString() << "]";
+            LOG(ERROR) << RoleToString() << "SSL_read failed [" << SSLErrorString() << "]";
             return false;
         }
         size -= bytes_read;
@@ -333,7 +340,7 @@
 bool TlsConnectionImpl::WriteFully(std::string_view data) {
     CHECK(!data.empty());
     if (!ssl_) {
-        LOG(ERROR) << "Tried to read on a null SSL connection";
+        LOG(ERROR) << RoleToString() << "Tried to write on a null SSL connection";
         return false;
     }
 
@@ -341,7 +348,7 @@
         int bytes_out = SSL_write(ssl_.get(), data.data(),
                                   std::min(static_cast<size_t>(INT_MAX), data.size()));
         if (bytes_out <= 0) {
-            LOG(WARNING) << "SSL_write failed [" << SSLErrorString() << "]";
+            LOG(ERROR) << RoleToString() << "SSL_write failed [" << SSLErrorString() << "]";
             return false;
         }
         data = data.substr(bytes_out);