You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
552 lines
19 KiB
552 lines
19 KiB
// Copyright 2015 The Chromium OS Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include <brillo/streams/tls_stream.h>
|
|
|
|
#include <algorithm>
|
|
#include <limits>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <openssl/err.h>
|
|
#include <openssl/ssl.h>
|
|
|
|
#include <base/bind.h>
|
|
#include <base/memory/weak_ptr.h>
|
|
#include <brillo/message_loops/message_loop.h>
|
|
#include <brillo/secure_blob.h>
|
|
#include <brillo/streams/openssl_stream_bio.h>
|
|
#include <brillo/streams/stream_utils.h>
|
|
#include <brillo/strings/string_utils.h>
|
|
|
|
namespace {
|
|
|
|
// SSL info callback which is called by OpenSSL when we enable logging level of
|
|
// at least 3. This logs the information about the internal TLS handshake.
|
|
void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
|
|
std::string reason;
|
|
std::vector<std::string> info;
|
|
if (where & SSL_CB_LOOP)
|
|
info.push_back("loop");
|
|
if (where & SSL_CB_EXIT)
|
|
info.push_back("exit");
|
|
if (where & SSL_CB_READ)
|
|
info.push_back("read");
|
|
if (where & SSL_CB_WRITE)
|
|
info.push_back("write");
|
|
if (where & SSL_CB_ALERT) {
|
|
info.push_back("alert");
|
|
reason = ", reason: ";
|
|
reason += SSL_alert_type_string_long(ret);
|
|
reason += "/";
|
|
reason += SSL_alert_desc_string_long(ret);
|
|
}
|
|
if (where & SSL_CB_HANDSHAKE_START)
|
|
info.push_back("handshake_start");
|
|
if (where & SSL_CB_HANDSHAKE_DONE)
|
|
info.push_back("handshake_done");
|
|
|
|
VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
|
|
<< ", with status: " << ret << reason;
|
|
}
|
|
|
|
// Index of the SSL_CTX ex-data slot in which a TlsStreamImpl stores a pointer
// to itself, so that the static certificate verification callback can find
// the instance again. Stays negative until a slot has been allocated.
int ssl_ctx_private_data_index{-1};
|
|
|
|
// Location of the trusted CA certificate store that is handed to OpenSSL when
// the SSL context is configured. The path depends on the target platform.
#ifdef __ANDROID__
const char kCACertificatePath[] = "/system/etc/security/cacerts_google";
#else
const char kCACertificatePath[] = "/usr/share/chromeos-ca-certificates";
#endif
|
|
|
|
} // anonymous namespace
|
|
|
|
namespace brillo {
|
|
|
|
// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped.
// When building against OpenSSL older than 1.1, TLS_client_method() is mapped
// onto the TLS 1.2 client method.
#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define TLS_client_method() TLSv1_2_client_method()
#endif
|
|
|
|
// Helper implementation of TLS stream used to hide most of OpenSSL inner
|
|
// workings from the users of brillo::TlsStream.
|
|
class TlsStream::TlsStreamImpl {
|
|
public:
|
|
TlsStreamImpl();
|
|
~TlsStreamImpl();
|
|
|
|
bool Init(StreamPtr socket,
|
|
const std::string& host,
|
|
const base::Closure& success_callback,
|
|
const Stream::ErrorCallback& error_callback,
|
|
ErrorPtr* error);
|
|
|
|
bool ReadNonBlocking(void* buffer,
|
|
size_t size_to_read,
|
|
size_t* size_read,
|
|
bool* end_of_stream,
|
|
ErrorPtr* error);
|
|
|
|
bool WriteNonBlocking(const void* buffer,
|
|
size_t size_to_write,
|
|
size_t* size_written,
|
|
ErrorPtr* error);
|
|
|
|
bool Flush(ErrorPtr* error);
|
|
bool Close(ErrorPtr* error);
|
|
bool WaitForData(AccessMode mode,
|
|
const base::Callback<void(AccessMode)>& callback,
|
|
ErrorPtr* error);
|
|
bool WaitForDataBlocking(AccessMode in_mode,
|
|
base::TimeDelta timeout,
|
|
AccessMode* out_mode,
|
|
ErrorPtr* error);
|
|
void CancelPendingAsyncOperations();
|
|
|
|
private:
|
|
bool ReportError(ErrorPtr* error,
|
|
const base::Location& location,
|
|
const std::string& message);
|
|
void DoHandshake(const base::Closure& success_callback,
|
|
const Stream::ErrorCallback& error_callback);
|
|
void RetryHandshake(const base::Closure& success_callback,
|
|
const Stream::ErrorCallback& error_callback,
|
|
Stream::AccessMode mode);
|
|
|
|
int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
|
|
static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
|
|
|
|
StreamPtr socket_;
|
|
std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
|
|
std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
|
|
BIO* stream_bio_{nullptr};
|
|
bool need_more_read_{false};
|
|
bool need_more_write_{false};
|
|
|
|
base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
|
|
DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
|
|
};
|
|
|
|
// Initializes the OpenSSL library and, the first time through, allocates the
// SSL_CTX ex-data slot used to store the TlsStreamImpl pointer.
TlsStream::TlsStreamImpl::TlsStreamImpl() {
  SSL_load_error_strings();
  SSL_library_init();

  // A non-negative index means the slot has already been allocated.
  if (ssl_ctx_private_data_index >= 0)
    return;

  ssl_ctx_private_data_index =
      SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
}
|
|
|
|
// Releases the OpenSSL objects explicitly: the session first, then the
// context it was created from.
TlsStream::TlsStreamImpl::~TlsStreamImpl() {
  ssl_ = nullptr;
  ctx_ = nullptr;
}
|
|
|
|
bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
|
|
size_t size_to_read,
|
|
size_t* size_read,
|
|
bool* end_of_stream,
|
|
ErrorPtr* error) {
|
|
const size_t max_int = std::numeric_limits<int>::max();
|
|
int size_int = static_cast<int>(std::min(size_to_read, max_int));
|
|
int ret = SSL_read(ssl_.get(), buffer, size_int);
|
|
if (ret > 0) {
|
|
*size_read = static_cast<size_t>(ret);
|
|
if (end_of_stream)
|
|
*end_of_stream = false;
|
|
return true;
|
|
}
|
|
|
|
int err = SSL_get_error(ssl_.get(), ret);
|
|
if (err == SSL_ERROR_ZERO_RETURN) {
|
|
*size_read = 0;
|
|
if (end_of_stream)
|
|
*end_of_stream = true;
|
|
return true;
|
|
}
|
|
|
|
if (err == SSL_ERROR_WANT_READ) {
|
|
need_more_read_ = true;
|
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
|
// Writes might be required for SSL_read() because of possible TLS
|
|
// re-negotiations which can happen at any time.
|
|
need_more_write_ = true;
|
|
} else {
|
|
return ReportError(error, FROM_HERE, "Error reading from TLS socket");
|
|
}
|
|
*size_read = 0;
|
|
if (end_of_stream)
|
|
*end_of_stream = false;
|
|
return true;
|
|
}
|
|
|
|
bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
|
|
size_t size_to_write,
|
|
size_t* size_written,
|
|
ErrorPtr* error) {
|
|
const size_t max_int = std::numeric_limits<int>::max();
|
|
int size_int = static_cast<int>(std::min(size_to_write, max_int));
|
|
int ret = SSL_write(ssl_.get(), buffer, size_int);
|
|
if (ret > 0) {
|
|
*size_written = static_cast<size_t>(ret);
|
|
return true;
|
|
}
|
|
|
|
int err = SSL_get_error(ssl_.get(), ret);
|
|
if (err == SSL_ERROR_WANT_READ) {
|
|
// Reads might be required for SSL_write() because of possible TLS
|
|
// re-negotiations which can happen at any time.
|
|
need_more_read_ = true;
|
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
|
need_more_write_ = true;
|
|
} else {
|
|
return ReportError(error, FROM_HERE, "Error writing to TLS socket");
|
|
}
|
|
*size_written = 0;
|
|
return true;
|
|
}
|
|
|
|
// Flushes the underlying socket stream and returns its result; any failure is
// reported through |error| by the socket.
bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
  const bool flushed = socket_->FlushBlocking(error);
  return flushed;
}
|
|
|
|
// Shuts down the TLS session (best effort) and closes the underlying socket.
//
// Only the local "close notify" alert is attempted; a full bi-directional
// shutdown is not awaited. Returns the result of closing the socket, or true
// if there is no socket to close. Errors are added to |error|.
//
// This is safe to call on an instance whose Init() failed part-way: in that
// case the SSL session and/or the socket may never have been created.
bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
  // 2 seconds should be plenty here.
  const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
  // The retry count below is just arbitrary, to ensure we don't get stuck
  // here forever. We should rarely need to repeat SSL_shutdown anyway.
  constexpr int kMaxShutdownAttempts = 4;

  // Without an SSL session there is nothing to shut down, and the wait calls
  // below need the socket, so skip the whole TLS shutdown if either is absent.
  if (ssl_ && socket_) {
    for (int attempt = 0; attempt < kMaxShutdownAttempts; attempt++) {
      // SSL_get_error() requires an empty error queue before the I/O call.
      ERR_clear_error();
      const int ret = SSL_shutdown(ssl_.get());
      // We really don't care for bi-directional shutdown here.
      // Just make sure we only send the "close notify" alert to the remote
      // peer.
      if (ret >= 0)
        break;

      const int err = SSL_get_error(ssl_.get(), ret);
      if (err == SSL_ERROR_WANT_READ) {
        if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
                                          error)) {
          break;
        }
      } else if (err == SSL_ERROR_WANT_WRITE) {
        if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
                                          error)) {
          break;
        }
      } else {
        LOG(ERROR) << "SSL_shutdown returned error #" << err;
        ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
        break;
      }
    }
  }

  // No socket was ever attached, so there is nothing left to close.
  if (!socket_)
    return true;
  return socket_->CloseBlocking(error);
}
|
|
|
|
// Asynchronously waits until the stream is ready for |mode|.
//
// The directions OpenSSL asked for during the last read/write are added to
// the requested mode, and those pending requests are then cleared. If a read
// is wanted and OpenSSL already holds buffered data, |callback| is invoked
// right away with READ and the socket is not consulted. Otherwise the wait is
// delegated to the underlying socket.
bool TlsStream::TlsStreamImpl::WaitForData(
    AccessMode mode,
    const base::Callback<void(AccessMode)>& callback,
    ErrorPtr* error) {
  const bool want_read =
      stream_utils::IsReadAccessMode(mode) || need_more_read_;
  const bool want_write =
      stream_utils::IsWriteAccessMode(mode) || need_more_write_;
  need_more_read_ = false;
  need_more_write_ = false;

  const bool has_buffered_data = want_read && SSL_pending(ssl_.get()) > 0;
  if (has_buffered_data) {
    callback.Run(AccessMode::READ);
    return true;
  }

  return socket_->WaitForData(
      stream_utils::MakeAccessMode(want_read, want_write), callback, error);
}
|
|
|
|
// Blocks until the stream is ready for |in_mode| or |timeout| expires.
//
// The directions OpenSSL asked for during the last read/write are added to
// the requested mode, and those pending requests are then cleared. If a read
// is wanted and OpenSSL already holds buffered data, the call returns at once
// and reports READ through |out_mode| (which may be null). Otherwise the wait
// is delegated to the underlying socket.
bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
                                                   base::TimeDelta timeout,
                                                   AccessMode* out_mode,
                                                   ErrorPtr* error) {
  const bool want_read =
      stream_utils::IsReadAccessMode(in_mode) || need_more_read_;
  const bool want_write =
      stream_utils::IsWriteAccessMode(in_mode) || need_more_write_;
  need_more_read_ = false;
  need_more_write_ = false;

  const bool has_buffered_data = want_read && SSL_pending(ssl_.get()) > 0;
  if (has_buffered_data) {
    if (out_mode)
      *out_mode = AccessMode::READ;
    return true;
  }

  const AccessMode socket_mode =
      stream_utils::MakeAccessMode(want_read, want_write);
  return socket_->WaitForDataBlocking(socket_mode, timeout, out_mode, error);
}
|
|
|
|
// Cancels whatever the socket is still waiting on and then invalidates all
// weak pointers to this object, so that handshake callbacks bound to them no
// longer run.
void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
  // Socket first, weak pointers second.
  socket_->CancelPendingAsyncOperations();
  weak_ptr_factory_.InvalidateWeakPtrs();
}
|
|
|
|
// Converts the pending OpenSSL error queue into brillo errors.
//
// Every queued OpenSSL error is removed from the queue and added to |error|
// in the "openssl" domain, using the error number as the code. Finally
// |message| is added in the "tls_stream" domain with |location| as its
// origin. Always returns false so callers can write
// `return ReportError(...)`.
bool TlsStream::TlsStreamImpl::ReportError(
    ErrorPtr* error,
    const base::Location& location,
    const std::string& message) {
  const char* src_file = nullptr;
  int src_line = 0;
  const char* extra_text = nullptr;
  int extra_flags = 0;

  for (;;) {
    const auto code = ERR_get_error_line_data(&src_file, &src_line,
                                              &extra_text, &extra_flags);
    if (!code)
      break;  // The queue is empty.

    char description[256];
    ERR_error_string_n(code, description, sizeof(description));
    base::Location ssl_location{"Unknown", src_file, src_line, nullptr};

    std::string ssl_message = description;
    // Append the optional free-form text only if OpenSSL marked it as a
    // string.
    if (extra_flags & ERR_TXT_STRING)
      ssl_message.append(": ").append(extra_text);

    Error::AddTo(error, ssl_location, "openssl", std::to_string(code),
                 ssl_message);
  }

  Error::AddTo(error, location, "tls_stream", "failed", message);
  return false;
}
|
|
|
|
int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
|
|
// OpenSSL already performs a comprehensive check of the certificate chain
|
|
// (using X509_verify_cert() function) and calls back with the result of its
|
|
// verification.
|
|
// |ok| is set to 1 if the verification passed and 0 if an error was detected.
|
|
// Here we can perform some additional checks if we need to, or simply log
|
|
// the issues found.
|
|
|
|
// For now, just log an error if it occurred.
|
|
if (!ok) {
|
|
LOG(ERROR) << "Server certificate validation failed: "
|
|
<< X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
|
|
}
|
|
return ok;
|
|
}
|
|
|
|
int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
|
|
X509_STORE_CTX* ctx) {
|
|
// Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
|
|
// SSL CTX object referenced by |ctx|.
|
|
SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
|
|
ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
|
|
SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
|
|
TlsStream::TlsStreamImpl* self = nullptr;
|
|
if (ssl_ctx) {
|
|
self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
|
|
ssl_ctx, ssl_ctx_private_data_index));
|
|
}
|
|
return self ? self->OnCertVerifyResults(ok, ctx) : ok;
|
|
}
|
|
|
|
// Sets up the OpenSSL context and session on top of |socket| and starts the
// TLS client handshake.
//
// |host| is the server name the peer certificate must match. The handshake
// result is delivered through |success_callback| or |error_callback|: via a
// task posted to the current message loop if there is one, otherwise
// synchronously from within this call.
//
// Returns true once the handshake has been started. Returns false and fills
// |error| if any part of the set-up fails; in that case neither callback is
// invoked by this method.
bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
                                    const std::string& host,
                                    const base::Closure& success_callback,
                                    const Stream::ErrorCallback& error_callback,
                                    ErrorPtr* error) {
  ctx_.reset(SSL_CTX_new(TLS_client_method()));
  if (!ctx_)
    return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");

  // Top cipher suites supported by both Google GFEs and OpenSSL (in server
  // preferred order).
  int res = SSL_CTX_set_cipher_list(ctx_.get(),
                                    "ECDHE-ECDSA-AES128-GCM-SHA256:"
                                    "ECDHE-ECDSA-AES256-GCM-SHA384:"
                                    "ECDHE-RSA-AES128-GCM-SHA256:"
                                    "ECDHE-RSA-AES256-GCM-SHA384");
  if (res != 1)
    return ReportError(error, FROM_HERE, "Cannot set the cipher list");

  res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
  if (res != 1) {
    return ReportError(error, FROM_HERE,
                       "Failed to specify trusted certificate location");
  }

  // Store a pointer to "this" into SSL_CTX instance.
  SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);

  // Ask OpenSSL to validate the server host from the certificate to match
  // the expected host name we are given. The result must be checked: if the
  // expected name could not be stored, continuing would mean connecting
  // without the host-name check this code relies on.
  X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
  res = X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
  if (res != 1) {
    return ReportError(error, FROM_HERE,
                       "Failed to set the expected server host name");
  }

  SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
                     &TlsStreamImpl::OnCertVerifyResultsStatic);

  socket_ = std::move(socket);
  ssl_.reset(SSL_new(ctx_.get()));
  if (!ssl_)
    return ReportError(error, FROM_HERE, "Cannot create SSL object");

  // Enable TLS progress callback if VLOG level is >=3.
  if (VLOG_IS_ON(3))
    SSL_set_info_callback(ssl_.get(), TlsInfoCallback);

  stream_bio_ = BIO_new_stream(socket_.get());
  if (!stream_bio_) {
    return ReportError(error, FROM_HERE,
                       "Cannot create BIO for the socket stream");
  }
  SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
  SSL_set_connect_state(ssl_.get());

  // We might have no message loop (e.g. we are in unit tests).
  if (MessageLoop::ThreadHasCurrent()) {
    MessageLoop::current()->PostTask(
        FROM_HERE,
        base::BindOnce(&TlsStreamImpl::DoHandshake,
                       weak_ptr_factory_.GetWeakPtr(),
                       success_callback,
                       error_callback));
  } else {
    DoHandshake(success_callback, error_callback);
  }
  return true;
}
|
|
|
|
void TlsStream::TlsStreamImpl::RetryHandshake(
|
|
const base::Closure& success_callback,
|
|
const Stream::ErrorCallback& error_callback,
|
|
Stream::AccessMode /* mode */) {
|
|
VLOG(1) << "Retrying TLS handshake";
|
|
DoHandshake(success_callback, error_callback);
|
|
}
|
|
|
|
void TlsStream::TlsStreamImpl::DoHandshake(
|
|
const base::Closure& success_callback,
|
|
const Stream::ErrorCallback& error_callback) {
|
|
VLOG(1) << "Begin TLS handshake";
|
|
int res = SSL_do_handshake(ssl_.get());
|
|
if (res == 1) {
|
|
VLOG(1) << "Handshake successful";
|
|
success_callback.Run();
|
|
return;
|
|
}
|
|
ErrorPtr error;
|
|
int err = SSL_get_error(ssl_.get(), res);
|
|
if (err == SSL_ERROR_WANT_READ) {
|
|
VLOG(1) << "Waiting for read data...";
|
|
bool ok = socket_->WaitForData(
|
|
Stream::AccessMode::READ,
|
|
base::Bind(&TlsStreamImpl::RetryHandshake,
|
|
weak_ptr_factory_.GetWeakPtr(),
|
|
success_callback, error_callback),
|
|
&error);
|
|
if (ok)
|
|
return;
|
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
|
VLOG(1) << "Waiting for write data...";
|
|
bool ok = socket_->WaitForData(
|
|
Stream::AccessMode::WRITE,
|
|
base::Bind(&TlsStreamImpl::RetryHandshake,
|
|
weak_ptr_factory_.GetWeakPtr(),
|
|
success_callback, error_callback),
|
|
&error);
|
|
if (ok)
|
|
return;
|
|
} else {
|
|
ReportError(&error, FROM_HERE, "TLS handshake failed.");
|
|
}
|
|
error_callback.Run(error.get());
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// Takes ownership of the implementation object.
TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
    : impl_(std::move(impl)) {}
|
|
|
|
// Closes the stream if it is still open. Errors are ignored because a
// destructor has no way to report them.
TlsStream::~TlsStream() {
  if (!impl_)
    return;
  impl_->Close(nullptr);
}
|
|
|
|
// Creates a TlsStream on top of |socket| and starts the TLS handshake with
// |host|.
//
// Ownership of the new stream is bound into the success callback, so
// |success_callback| receives the connected stream once the handshake
// completes. If initialization fails, |error_callback| is invoked with the
// error.
void TlsStream::Connect(StreamPtr socket,
                        const std::string& host,
                        const base::Callback<void(StreamPtr)>& success_callback,
                        const Stream::ErrorCallback& error_callback) {
  std::unique_ptr<TlsStreamImpl> new_impl(new TlsStreamImpl);
  std::unique_ptr<TlsStream> stream(new TlsStream(std::move(new_impl)));

  // Grab the raw implementation pointer before |stream| is moved into the
  // bound callback below.
  TlsStreamImpl* const pimpl = stream->impl_.get();
  ErrorPtr error;
  const bool initialized = pimpl->Init(
      std::move(socket), host,
      base::Bind(success_callback, base::Passed(std::move(stream))),
      error_callback, &error);
  if (initialized)
    return;

  error_callback.Run(error.get());
}
|
|
|
|
bool TlsStream::IsOpen() const {
|
|
return impl_ ? true : false;
|
|
}
|
|
|
|
bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
|
|
return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
|
|
}
|
|
|
|
bool TlsStream::Seek(int64_t /* offset */,
|
|
Whence /* whence */,
|
|
uint64_t* /* new_position*/,
|
|
ErrorPtr* error) {
|
|
return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
|
|
}
|
|
|
|
bool TlsStream::ReadNonBlocking(void* buffer,
|
|
size_t size_to_read,
|
|
size_t* size_read,
|
|
bool* end_of_stream,
|
|
ErrorPtr* error) {
|
|
if (!impl_)
|
|
return stream_utils::ErrorStreamClosed(FROM_HERE, error);
|
|
return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
|
|
error);
|
|
}
|
|
|
|
bool TlsStream::WriteNonBlocking(const void* buffer,
|
|
size_t size_to_write,
|
|
size_t* size_written,
|
|
ErrorPtr* error) {
|
|
if (!impl_)
|
|
return stream_utils::ErrorStreamClosed(FROM_HERE, error);
|
|
return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
|
|
}
|
|
|
|
// Flushes through the implementation object, or reports a "stream closed"
// error if the stream has already been closed.
bool TlsStream::FlushBlocking(ErrorPtr* error) {
  if (impl_)
    return impl_->Flush(error);
  return stream_utils::ErrorStreamClosed(FROM_HERE, error);
}
|
|
|
|
// Closes the stream. Closing a stream that is already closed succeeds. If
// closing the implementation fails, the implementation object is kept and
// false is returned; otherwise it is released.
bool TlsStream::CloseBlocking(ErrorPtr* error) {
  if (impl_) {
    const bool closed = impl_->Close(error);
    if (!closed)
      return false;
  }
  impl_.reset();
  return true;
}
|
|
|
|
// Forwards the asynchronous wait to the implementation object, or reports a
// "stream closed" error if the stream has already been closed.
bool TlsStream::WaitForData(AccessMode mode,
                            const base::Callback<void(AccessMode)>& callback,
                            ErrorPtr* error) {
  if (impl_)
    return impl_->WaitForData(mode, callback, error);
  return stream_utils::ErrorStreamClosed(FROM_HERE, error);
}
|
|
|
|
// Forwards the blocking wait to the implementation object, or reports a
// "stream closed" error if the stream has already been closed.
bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
                                    base::TimeDelta timeout,
                                    AccessMode* out_mode,
                                    ErrorPtr* error) {
  if (impl_)
    return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
  return stream_utils::ErrorStreamClosed(FROM_HERE, error);
}
|
|
|
|
void TlsStream::CancelPendingAsyncOperations() {
|
|
if (impl_)
|
|
impl_->CancelPendingAsyncOperations();
|
|
Stream::CancelPendingAsyncOperations();
|
|
}
|
|
|
|
} // namespace brillo
|