You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
8.8 KiB
209 lines
8.8 KiB
4 months ago
|
/*
|
||
|
* Copyright (C) 2018 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#ifndef _DNS_DNSTLSSOCKET_H
|
||
|
#define _DNS_DNSTLSSOCKET_H
|
||
|
|
||
|
#include <openssl/ssl.h>
|
||
|
#include <future>
|
||
|
#include <mutex>
|
||
|
|
||
|
#include <android-base/thread_annotations.h>
|
||
|
#include <android-base/unique_fd.h>
|
||
|
#include <netdutils/Slice.h>
|
||
|
#include <netdutils/Status.h>
|
||
|
|
||
|
#include "DnsTlsServer.h"
|
||
|
#include "IDnsTlsSocket.h"
|
||
|
#include "LockedQueue.h"
|
||
|
|
||
|
namespace android {
|
||
|
namespace net {
|
||
|
|
||
|
// Forward declarations: this header only refers to these types through pointers
// (the observer and session-cache constructor arguments and members), so their full
// definitions are not needed here.
class DnsTlsSessionCache;
class IDnsTlsSocketObserver;
|
||
|
|
||
|
// A class for managing a TLS socket that sends and receives messages in
|
||
|
// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
|
||
|
// This class is not aware of query-response pairing or anything else about DNS.
|
||
|
// For the observer:
|
||
|
// This class is not re-entrant: the observer is not permitted to wait for a call to query()
|
||
|
// or the destructor in a callback. Doing so will result in deadlocks.
|
||
|
// This class may call the observer at any time after initialize(), until the destructor
|
||
|
// returns (but not after).
|
||
|
//
|
||
|
// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
|
||
|
//
|
||
|
// UNINITIALIZED
|
||
|
// |
|
||
|
// v
|
||
|
// INITIALIZED
|
||
|
// |
|
||
|
// v
|
||
|
// +----CONNECTING------+
|
||
|
// Handshake fails | | Handshake succeeds
|
||
|
// (onClose() when | |
|
||
|
// mAsyncHandshake is set) | v
|
||
|
// | +---> CONNECTED --+
|
||
|
// | | | |
|
||
|
// | +-----------+ | Idle timeout
|
||
|
// | Send/Recv queries | onClose()
|
||
|
// | onResponse() |
|
||
|
// | |
|
||
|
// | |
|
||
|
// +--> WAIT_FOR_DELETE <-----+
|
||
|
//
|
||
|
//
|
||
|
// TODO: Add onHandshakeFinished() for handshake results.
|
||
|
class DnsTlsSocket : public IDnsTlsSocket {
|
||
|
public:
|
||
|
enum class State {
|
||
|
UNINITIALIZED,
|
||
|
INITIALIZED,
|
||
|
CONNECTING,
|
||
|
CONNECTED,
|
||
|
WAIT_FOR_DELETE,
|
||
|
};
|
||
|
|
||
|
DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
|
||
|
IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
|
||
|
: mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
|
||
|
~DnsTlsSocket();
|
||
|
|
||
|
// Creates the SSL context for this session. Returns false on failure.
|
||
|
// This method should be called after construction and before use of a DnsTlsSocket.
|
||
|
// Only call this method once per DnsTlsSocket.
|
||
|
bool initialize() EXCLUDES(mLock);
|
||
|
|
||
|
// If async handshake is enabled, this function simply signals a handshake request, and the
|
||
|
// handshake will be performed in the loop thread; otherwise, if async handshake is disabled,
|
||
|
// this function performs the handshake and returns after the handshake finishes.
|
||
|
bool startHandshake() EXCLUDES(mLock);
|
||
|
|
||
|
// Send a query on the provided SSL socket. |query| contains
|
||
|
// the body of a query, not including the ID header. This function will typically return before
|
||
|
// the query is actually sent. If this function fails, DnsTlsSocketObserver will be
|
||
|
// notified that the socket is closed.
|
||
|
// Note that success here indicates successful sending, not receipt of a response.
|
||
|
// Thread-safe.
|
||
|
bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
|
||
|
|
||
|
private:
|
||
|
// Lock to be held by the SSL event loop thread. This is not normally in contention.
|
||
|
std::mutex mLock;
|
||
|
|
||
|
// Forwards queries and receives responses. Blocks until the idle timeout.
|
||
|
void loop() EXCLUDES(mLock);
|
||
|
std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
|
||
|
|
||
|
// On success, sets mSslFd to a socket connected to mAddr (the
|
||
|
// connection will likely be in progress if mProtocol is IPPROTO_TCP).
|
||
|
// On error, returns the errno.
|
||
|
netdutils::Status tcpConnect() REQUIRES(mLock);
|
||
|
|
||
|
bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock);
|
||
|
|
||
|
// Connect an SSL session on the provided socket. If connection fails, closing the
|
||
|
// socket remains the caller's responsibility.
|
||
|
bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
|
||
|
|
||
|
// Connect an SSL session on the provided socket. This is an interruptible version
|
||
|
// which allows to terminate connection handshake any time.
|
||
|
bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock);
|
||
|
|
||
|
// Disconnect the SSL session and close the socket.
|
||
|
void sslDisconnect() REQUIRES(mLock);
|
||
|
|
||
|
// Writes a buffer to the socket.
|
||
|
bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);
|
||
|
|
||
|
// Reads exactly the specified number of bytes from the socket, or fails.
|
||
|
// Returns SSL_ERROR_NONE on success.
|
||
|
// If |wait| is true, then this function always blocks. Otherwise, it
|
||
|
// will return SSL_ERROR_WANT_READ if there is no data from the server to read.
|
||
|
int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
|
||
|
|
||
|
bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
|
||
|
|
||
|
// Read one DNS response. It can potentially block until reading the exact bytes of
|
||
|
// the response.
|
||
|
bool readResponse() REQUIRES(mLock);
|
||
|
|
||
|
// It is only used for DNS-OVER-TLS internal test.
|
||
|
bool setTestCaCertificate() REQUIRES(mLock);
|
||
|
|
||
|
// Similar to query(), this function uses incrementEventFd to send a message to the
|
||
|
// loop thread. However, instead of incrementing the counter by one (indicating a
|
||
|
// new query), it wraps the counter to negative, which we use to indicate a shutdown
|
||
|
// request.
|
||
|
void requestLoopShutdown() EXCLUDES(mLock);
|
||
|
|
||
|
// This function sends a message to the loop thread by incrementing mEventFd.
|
||
|
bool incrementEventFd(int64_t count) EXCLUDES(mLock);
|
||
|
|
||
|
// Transition the state from expected state |from| to new state |to|.
|
||
|
void transitionState(State from, State to) REQUIRES(mLock);
|
||
|
|
||
|
// Queue of pending queries. query() pushes items onto the queue and notifies
|
||
|
// the loop thread by incrementing mEventFd. loop() reads items off the queue.
|
||
|
LockedQueue<std::vector<uint8_t>> mQueue;
|
||
|
|
||
|
// eventfd socket used for notifying the SSL thread when queries are ready to send.
|
||
|
// This socket acts similarly to an atomic counter, incremented by query() and cleared
|
||
|
// by loop(). We have to use a socket because the SSL thread needs to wait in poll()
|
||
|
// for input from either a remote server or a query thread. Since eventfd does not have
|
||
|
// EOF, we indicate a close request by setting the counter to a negative number.
|
||
|
// This file descriptor is opened by initialize(), and closed implicitly after
|
||
|
// destruction.
|
||
|
// Note that: data starts being read from the eventfd when the state is CONNECTED.
|
||
|
base::unique_fd mEventFd;
|
||
|
|
||
|
// An eventfd used to listen to shutdown requests when the state is CONNECTING.
|
||
|
// TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively
|
||
|
// handle shutdown requests.
|
||
|
base::unique_fd mShutdownEvent;
|
||
|
|
||
|
// SSL Socket fields.
|
||
|
bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
|
||
|
base::unique_fd mSslFd GUARDED_BY(mLock);
|
||
|
bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
|
||
|
static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);
|
||
|
|
||
|
const unsigned mMark; // Socket mark
|
||
|
const DnsTlsServer mServer;
|
||
|
IDnsTlsSocketObserver* _Nonnull const mObserver;
|
||
|
DnsTlsSessionCache* _Nonnull const mCache;
|
||
|
State mState GUARDED_BY(mLock) = State::UNINITIALIZED;
|
||
|
|
||
|
// If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's
|
||
|
// thread (the call to startHandshake()).
|
||
|
bool mAsyncHandshake GUARDED_BY(mLock) = false;
|
||
|
|
||
|
// The time to wait for the attempt on connecting to the server.
|
||
|
// Set the default value 127 seconds to be consistent with TCP connect timeout.
|
||
|
// (presume net.ipv4.tcp_syn_retries = 6)
|
||
|
static constexpr int kDotConnectTimeoutMs = 127 * 1000;
|
||
|
int mConnectTimeoutMs;
|
||
|
|
||
|
// For testing.
|
||
|
friend class DnsTlsSocketTest;
|
||
|
};
|
||
|
|
||
|
} // end of namespace net
|
||
|
} // end of namespace android
|
||
|
|
||
|
#endif // _DNS_DNSTLSSOCKET_H
|