Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2018 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | #ifndef _DNS_DNSTLSSOCKET_H |
| 18 | #define _DNS_DNSTLSSOCKET_H |
| 19 | |
#include <openssl/ssl.h>
#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 32 | |
| 33 | namespace android { |
| 34 | namespace net { |
| 35 | |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 36 | class IDnsTlsSocketObserver; |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 37 | class DnsTlsSessionCache; |
| 38 | |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 39 | // A class for managing a TLS socket that sends and receives messages in |
| 40 | // [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format). |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 41 | // This class is not aware of query-response pairing or anything else about DNS. |
| 42 | // For the observer: |
| 43 | // This class is not re-entrant: the observer is not permitted to wait for a call to query() |
| 44 | // or the destructor in a callback. Doing so will result in deadlocks. |
| 45 | // This class may call the observer at any time after initialize(), until the destructor |
| 46 | // returns (but not after). |
Mike Yu | a46fae7 | 2018-11-01 20:07:00 +0800 | [diff] [blame] | 47 | class DnsTlsSocket : public IDnsTlsSocket { |
| 48 | public: |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 49 | DnsTlsSocket(const DnsTlsServer& server, unsigned mark, |
Bernie Innocenti | ad4e26e | 2019-01-30 11:16:36 +0900 | [diff] [blame] | 50 | IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache) |
| 51 | : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {} |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 52 | ~DnsTlsSocket(); |
| 53 | |
| 54 | // Creates the SSL context for this session and connect. Returns false on failure. |
| 55 | // This method should be called after construction and before use of a DnsTlsSocket. |
| 56 | // Only call this method once per DnsTlsSocket. |
| 57 | bool initialize() EXCLUDES(mLock); |
| 58 | |
| 59 | // Send a query on the provided SSL socket. |query| contains |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 60 | // the body of a query, not including the ID header. This function will typically return before |
| 61 | // the query is actually sent. If this function fails, DnsTlsSocketObserver will be |
| 62 | // notified that the socket is closed. |
| 63 | // Note that success here indicates successful sending, not receipt of a response. |
| 64 | // Thread-safe. |
Ben Schwartz | 18329ec | 2019-01-22 17:32:17 -0500 | [diff] [blame] | 65 | bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock); |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 66 | |
Bernie Innocenti | ad4e26e | 2019-01-30 11:16:36 +0900 | [diff] [blame] | 67 | private: |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 68 | // Lock to be held by the SSL event loop thread. This is not normally in contention. |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 69 | std::mutex mLock; |
| 70 | |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 71 | // Forwards queries and receives responses. Blocks until the idle timeout. |
| 72 | void loop() EXCLUDES(mLock); |
| 73 | std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock); |
| 74 | |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 75 | // On success, sets mSslFd to a socket connected to mAddr (the |
| 76 | // connection will likely be in progress if mProtocol is IPPROTO_TCP). |
| 77 | // On error, returns the errno. |
| 78 | netdutils::Status tcpConnect() REQUIRES(mLock); |
| 79 | |
| 80 | // Connect an SSL session on the provided socket. If connection fails, closing the |
| 81 | // socket remains the caller's responsibility. |
| 82 | bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock); |
| 83 | |
| 84 | // Disconnect the SSL session and close the socket. |
| 85 | void sslDisconnect() REQUIRES(mLock); |
| 86 | |
| 87 | // Writes a buffer to the socket. |
Bernie Innocenti | ad4e26e | 2019-01-30 11:16:36 +0900 | [diff] [blame] | 88 | bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock); |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 89 | |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 90 | // Reads exactly the specified number of bytes from the socket, or fails. |
| 91 | // Returns SSL_ERROR_NONE on success. |
| 92 | // If |wait| is true, then this function always blocks. Otherwise, it |
| 93 | // will return SSL_ERROR_WANT_READ if there is no data from the server to read. |
Bernie Innocenti | ad4e26e | 2019-01-30 11:16:36 +0900 | [diff] [blame] | 94 | int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock); |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 95 | |
Ben Schwartz | bfc8d99 | 2019-01-10 14:30:46 -0500 | [diff] [blame] | 96 | bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 97 | bool readResponse() REQUIRES(mLock); |
| 98 | |
Ben Schwartz | 18329ec | 2019-01-22 17:32:17 -0500 | [diff] [blame] | 99 | // Similar to query(), this function uses incrementEventFd to send a message to the |
| 100 | // loop thread. However, instead of incrementing the counter by one (indicating a |
| 101 | // new query), it wraps the counter to negative, which we use to indicate a shutdown |
| 102 | // request. |
| 103 | void requestLoopShutdown() EXCLUDES(mLock); |
| 104 | |
| 105 | // This function sends a message to the loop thread by incrementing mEventFd. |
| 106 | bool incrementEventFd(int64_t count) EXCLUDES(mLock); |
| 107 | |
Ben Schwartz | bfc8d99 | 2019-01-10 14:30:46 -0500 | [diff] [blame] | 108 | // Queue of pending queries. query() pushes items onto the queue and notifies |
| 109 | // the loop thread by incrementing mEventFd. loop() reads items off the queue. |
| 110 | LockedQueue<std::vector<uint8_t>> mQueue; |
| 111 | |
| 112 | // eventfd socket used for notifying the SSL thread when queries are ready to send. |
| 113 | // This socket acts similarly to an atomic counter, incremented by query() and cleared |
| 114 | // by loop(). We have to use a socket because the SSL thread needs to wait in poll() |
Ben Schwartz | 18329ec | 2019-01-22 17:32:17 -0500 | [diff] [blame] | 115 | // for input from either a remote server or a query thread. Since eventfd does not have |
| 116 | // EOF, we indicate a close request by setting the counter to a negative number. |
| 117 | // This file descriptor is opened by initialize(), and closed implicitly after |
| 118 | // destruction. |
Ben Schwartz | bfc8d99 | 2019-01-10 14:30:46 -0500 | [diff] [blame] | 119 | base::unique_fd mEventFd; |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 120 | |
| 121 | // SSL Socket fields. |
| 122 | bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock); |
| 123 | base::unique_fd mSslFd GUARDED_BY(mLock); |
| 124 | bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock); |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 125 | static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20); |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 126 | |
| 127 | const unsigned mMark; // Socket mark |
| 128 | const DnsTlsServer mServer; |
Ben Schwartz | 3386076 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 129 | IDnsTlsSocketObserver* _Nonnull const mObserver; |
Ben Schwartz | ded1b70 | 2017-10-25 14:41:02 -0400 | [diff] [blame] | 130 | DnsTlsSessionCache* _Nonnull const mCache; |
| 131 | }; |
| 132 | |
| 133 | } // end of namespace net |
| 134 | } // end of namespace android |
| 135 | |
| 136 | #endif // _DNS_DNSTLSSOCKET_H |