blob: 2940500b4c253ede8006c0eae544d212a2eb656c [file] [log] [blame]
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#ifndef _DNS_DNSTLSSOCKET_H
18#define _DNS_DNSTLSSOCKET_H
19
#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>
#include <openssl/ssl.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"

33namespace android {
34namespace net {
35
Ben Schwartz33860762017-10-25 14:41:02 -040036class IDnsTlsSocketObserver;
Ben Schwartzded1b702017-10-25 14:41:02 -040037class DnsTlsSessionCache;
38
Ben Schwartzded1b702017-10-25 14:41:02 -040039// A class for managing a TLS socket that sends and receives messages in
40// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
Ben Schwartz33860762017-10-25 14:41:02 -040041// This class is not aware of query-response pairing or anything else about DNS.
42// For the observer:
43// This class is not re-entrant: the observer is not permitted to wait for a call to query()
44// or the destructor in a callback. Doing so will result in deadlocks.
45// This class may call the observer at any time after initialize(), until the destructor
46// returns (but not after).
Mike Yua46fae72018-11-01 20:07:00 +080047class DnsTlsSocket : public IDnsTlsSocket {
48 public:
Ben Schwartzded1b702017-10-25 14:41:02 -040049 DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
Bernie Innocentiad4e26e2019-01-30 11:16:36 +090050 IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
51 : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
Ben Schwartzded1b702017-10-25 14:41:02 -040052 ~DnsTlsSocket();
53
54 // Creates the SSL context for this session and connect. Returns false on failure.
55 // This method should be called after construction and before use of a DnsTlsSocket.
56 // Only call this method once per DnsTlsSocket.
57 bool initialize() EXCLUDES(mLock);
58
59 // Send a query on the provided SSL socket. |query| contains
Ben Schwartz33860762017-10-25 14:41:02 -040060 // the body of a query, not including the ID header. This function will typically return before
61 // the query is actually sent. If this function fails, DnsTlsSocketObserver will be
62 // notified that the socket is closed.
63 // Note that success here indicates successful sending, not receipt of a response.
64 // Thread-safe.
Ben Schwartz18329ec2019-01-22 17:32:17 -050065 bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
Ben Schwartzded1b702017-10-25 14:41:02 -040066
Bernie Innocentiad4e26e2019-01-30 11:16:36 +090067 private:
Ben Schwartz33860762017-10-25 14:41:02 -040068 // Lock to be held by the SSL event loop thread. This is not normally in contention.
Ben Schwartzded1b702017-10-25 14:41:02 -040069 std::mutex mLock;
70
Ben Schwartz33860762017-10-25 14:41:02 -040071 // Forwards queries and receives responses. Blocks until the idle timeout.
72 void loop() EXCLUDES(mLock);
73 std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
74
Ben Schwartzded1b702017-10-25 14:41:02 -040075 // On success, sets mSslFd to a socket connected to mAddr (the
76 // connection will likely be in progress if mProtocol is IPPROTO_TCP).
77 // On error, returns the errno.
78 netdutils::Status tcpConnect() REQUIRES(mLock);
79
80 // Connect an SSL session on the provided socket. If connection fails, closing the
81 // socket remains the caller's responsibility.
82 bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
83
84 // Disconnect the SSL session and close the socket.
85 void sslDisconnect() REQUIRES(mLock);
86
87 // Writes a buffer to the socket.
Bernie Innocentiad4e26e2019-01-30 11:16:36 +090088 bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);
Ben Schwartzded1b702017-10-25 14:41:02 -040089
Ben Schwartz33860762017-10-25 14:41:02 -040090 // Reads exactly the specified number of bytes from the socket, or fails.
91 // Returns SSL_ERROR_NONE on success.
92 // If |wait| is true, then this function always blocks. Otherwise, it
93 // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
Bernie Innocentiad4e26e2019-01-30 11:16:36 +090094 int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
Ben Schwartzded1b702017-10-25 14:41:02 -040095
Ben Schwartzbfc8d992019-01-10 14:30:46 -050096 bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -040097 bool readResponse() REQUIRES(mLock);
98
Ben Schwartz18329ec2019-01-22 17:32:17 -050099 // Similar to query(), this function uses incrementEventFd to send a message to the
100 // loop thread. However, instead of incrementing the counter by one (indicating a
101 // new query), it wraps the counter to negative, which we use to indicate a shutdown
102 // request.
103 void requestLoopShutdown() EXCLUDES(mLock);
104
105 // This function sends a message to the loop thread by incrementing mEventFd.
106 bool incrementEventFd(int64_t count) EXCLUDES(mLock);
107
Ben Schwartzbfc8d992019-01-10 14:30:46 -0500108 // Queue of pending queries. query() pushes items onto the queue and notifies
109 // the loop thread by incrementing mEventFd. loop() reads items off the queue.
110 LockedQueue<std::vector<uint8_t>> mQueue;
111
112 // eventfd socket used for notifying the SSL thread when queries are ready to send.
113 // This socket acts similarly to an atomic counter, incremented by query() and cleared
114 // by loop(). We have to use a socket because the SSL thread needs to wait in poll()
Ben Schwartz18329ec2019-01-22 17:32:17 -0500115 // for input from either a remote server or a query thread. Since eventfd does not have
116 // EOF, we indicate a close request by setting the counter to a negative number.
117 // This file descriptor is opened by initialize(), and closed implicitly after
118 // destruction.
Ben Schwartzbfc8d992019-01-10 14:30:46 -0500119 base::unique_fd mEventFd;
Ben Schwartzded1b702017-10-25 14:41:02 -0400120
121 // SSL Socket fields.
122 bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
123 base::unique_fd mSslFd GUARDED_BY(mLock);
124 bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
Ben Schwartz33860762017-10-25 14:41:02 -0400125 static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);
Ben Schwartzded1b702017-10-25 14:41:02 -0400126
127 const unsigned mMark; // Socket mark
128 const DnsTlsServer mServer;
Ben Schwartz33860762017-10-25 14:41:02 -0400129 IDnsTlsSocketObserver* _Nonnull const mObserver;
Ben Schwartzded1b702017-10-25 14:41:02 -0400130 DnsTlsSessionCache* _Nonnull const mCache;
131};
132
133} // end of namespace net
134} // end of namespace android
135
136#endif // _DNS_DNSTLSSOCKET_H