You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

492 lines
18 KiB

/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "adb/pairing/pairing_connection.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <sys/socket.h>

#include <atomic>
#include <functional>
#include <memory>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

#include <adb/pairing/pairing_auth.h>
#include <adb/tls/tls_connection.h>
#include <android-base/endian.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/unique_fd.h>

#include "pairing.pb.h"
using namespace adb;
using android::base::unique_fd;
using TlsError = tls::TlsConnection::TlsError;

// Version byte we write into every PairingPacketHeader, and the inclusive
// range of header versions we accept from a peer.
constexpr uint8_t kCurrentKeyHeaderVersion = 1;
constexpr uint8_t kMinSupportedKeyHeaderVersion = 1;
constexpr uint8_t kMaxSupportedKeyHeaderVersion = 1;

// Upper bound on the payload size accepted in a PairingPacketHeader: twice
// the PeerInfo size (presumably headroom for encryption overhead — TODO confirm).
constexpr uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
// Fixed-size header that precedes every pairing packet on the wire.
// The struct is packed so it can be sent and received as raw bytes. In memory
// |payload| holds a host-order value; the code converts it with htonl/ntohl
// when writing/reading the wire form.
struct PairingPacketHeader {
    uint8_t version;   // PairingPacket version
    uint8_t type;      // the type of packet (PairingPacket.Type)
    uint32_t payload;  // Size of the payload in bytes
} __attribute__((packed));
// The wire format is exactly 1 + 1 + 4 bytes; catch any accidental padding.
static_assert(sizeof(PairingPacketHeader) == 6, "PairingPacketHeader must be 6 bytes on the wire");
// Deleter that lets std::unique_ptr release a PairingAuthCtx through the
// pairing_auth C API.
struct PairingAuthDeleter {
    void operator()(PairingAuthCtx* ctx) const { pairing_auth_destroy(ctx); }
};  // PairingAuthDeleter

// Owning handle to a PairingAuthCtx.
using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
// PairingConnectionCtx encapsulates the protocol to authenticate two peers with
// each other. It runs the pairing exchange on a background thread over an
// already-opened socket fd that the caller hands to Start(). On completion,
// both sides will have each other's public key (certificate) if successful;
// otherwise, the pairing failed.
//
// Each PairingConnectionCtx instance represents a different device trying to
// pair. So for the device, we can have multiple PairingConnectionCtxs while the
// host may have only one (unless host has a PairingServer).
//
// The worker thread started by Start() captures |this|, so instances are
// neither copyable nor movable.
//
// See pairing_connection_test.cpp for example usage.
//
struct PairingConnectionCtx {
  public:
    using Data = std::vector<uint8_t>;
    using ResultCallback = pairing_result_cb;
    enum class Role {
        Client,
        Server,
    };

    explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
                                  const Data& certificate, const Data& priv_key);
    virtual ~PairingConnectionCtx();

    PairingConnectionCtx(const PairingConnectionCtx&) = delete;
    PairingConnectionCtx& operator=(const PairingConnectionCtx&) = delete;
    PairingConnectionCtx(PairingConnectionCtx&&) = delete;
    PairingConnectionCtx& operator=(PairingConnectionCtx&&) = delete;

    // Starts the pairing connection on a separate thread.
    // Upon completion, if the pairing was successful,
    // |cb| will be called with the peer information and certificate.
    // Otherwise, |cb| will be called with empty data. |fd| should already
    // be opened. PairingConnectionCtx will take ownership of the |fd|.
    //
    // Pairing is successful if both server/client uses the same non-empty
    // |pswd|, and they are able to exchange the information. |pswd| and
    // |certificate| must be non-empty. Start() can only be called once in the
    // lifetime of this object.
    //
    // Returns true if the thread was successfully started, false otherwise.
    bool Start(int fd, ResultCallback cb, void* opaque);

  private:
    // Setup the tls connection.
    bool SetupTlsConnection();

    /************ PairingPacketHeader methods ****************/
    // Tries to write out the header and payload.
    bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
    // Tries to parse incoming data into the |header|. Returns true if header
    // is valid and header version is supported. |header| is filled on success.
    // |header| may contain garbage if unsuccessful.
    bool ReadHeader(PairingPacketHeader* header);
    // Creates a PairingPacketHeader.
    void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
                      uint32_t payload_size);
    // Checks if actual matches expected.
    bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);

    /*********** State related methods **************/
    // Handles the State::ExchangingMsgs state.
    bool DoExchangeMsgs();
    // Handles the State::ExchangingPeerInfo state.
    bool DoExchangePeerInfo();

    // The background task to do the pairing.
    void StartWorker();

    // Calls |cb_| and sets the state to Stopped.
    void NotifyResult(const PeerInfo* p);

    static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);

    enum class State {
        Ready,
        ExchangingMsgs,
        ExchangingPeerInfo,
        Stopped,
    };

    std::atomic<State> state_{State::Ready};
    Role role_;
    Data pswd_;
    PeerInfo peer_info_;
    Data cert_;
    Data priv_key_;

    // Peer's info
    PeerInfo their_info_;

    ResultCallback cb_;
    void* opaque_ = nullptr;
    // NOTE: declaration order matters for destruction order (thread_ and fd_
    // are destroyed before auth_ and tls_); do not reorder.
    std::unique_ptr<tls::TlsConnection> tls_;
    PairingAuthPtr auth_;
    unique_fd fd_;
    std::thread thread_;
    // Number of TLS keying-material bytes appended to the password.
    static constexpr size_t kExportedKeySize = 64;
};  // PairingConnectionCtx
// Copies the role, password, local PeerInfo, certificate and private key into
// the context. Aborts (CHECK) if the password, certificate or private key is
// empty. No I/O happens until Start() is called.
PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
                                           const Data& cert, const Data& priv_key)
    : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) {
    CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
}
// Stops the pairing worker (if one was started) and releases the socket.
//
// The worker thread may be blocked in a TLS read/write on |fd_|. Closing an
// fd from another thread does not wake a blocked system call on Linux. It also
// races with the worker reading |fd_|. So the socket is first shut down, which
// makes pending and future I/O on it fail. The worker is then joined, and only
// afterwards is the fd closed.
PairingConnectionCtx::~PairingConnectionCtx() {
    if (fd_.get() >= 0) {
        // Best effort: fails harmlessly (e.g. ENOTSOCK) if |fd_| is not a socket.
        shutdown(fd_.get(), SHUT_RDWR);
    }
    if (thread_.joinable()) {
        thread_.join();
    }
    fd_.reset();
}
bool PairingConnectionCtx::SetupTlsConnection() {
tls_ = tls::TlsConnection::Create(
role_ == Role::Server ? tls::TlsConnection::Role::Server
: tls::TlsConnection::Role::Client,
std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
fd_);
if (tls_ == nullptr) {
LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
return false;
}
// Allow any peer certificate
tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
// SSL doesn't seem to behave correctly with fdevents so just do a blocking
// read for the pairing data.
if (tls_->DoHandshake() != TlsError::Success) {
LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
return false;
}
// To ensure the connection is not stolen while we do the PAKE, append the
// exported key material from the tls connection to the password.
std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
if (exportedKeyMaterial.empty()) {
LOG(ERROR) << "Failed to export key material";
return false;
}
pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()),
std::make_move_iterator(exportedKeyMaterial.end()));
auth_ = CreatePairingAuthPtr(role_, pswd_);
return true;
}
bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
std::string_view payload) {
PairingPacketHeader network_header = *header;
network_header.payload = htonl(network_header.payload);
if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
sizeof(PairingPacketHeader))) ||
!tls_->WriteFully(payload)) {
LOG(ERROR) << "Failed to write out PairingPacketHeader";
state_ = State::Stopped;
return false;
}
return true;
}
// Reads one PairingPacketHeader from the TLS connection into |header|.
// Returns true only if all of the following hold:
//   - a complete header was read;
//   - its version is in [kMinSupportedKeyHeaderVersion, kMaxSupportedKeyHeaderVersion];
//   - its type is a known PairingPacket::Type;
//   - its payload size is in (0, kMaxPayloadSize].
// On success |header->payload| is in host byte order. |header| may be
// partially filled on failure.
bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
    auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
    if (data.size() != sizeof(PairingPacketHeader)) {
        return false;
    }
    const uint8_t* p = data.data();

    // First byte is always PairingPacketHeader version
    header->version = *p;
    ++p;
    if (header->version < kMinSupportedKeyHeaderVersion ||
        header->version > kMaxSupportedKeyHeaderVersion) {
        // Cast so the versions are logged as numbers rather than raw characters.
        LOG(ERROR) << "PairingPacketHeader version mismatch (us="
                   << static_cast<uint32_t>(kCurrentKeyHeaderVersion)
                   << " them=" << static_cast<uint32_t>(header->version) << ")";
        return false;
    }

    // Next byte is the PairingPacket::Type
    if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
        LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
        return false;
    }
    header->type = *p;
    ++p;

    // Last, the payload size in network byte order. It sits at offset 2, i.e.
    // unaligned, so copy it out instead of dereferencing a casted pointer
    // (which is undefined behavior and can fault on strict-alignment CPUs).
    uint32_t payload_be;
    memcpy(&payload_be, p, sizeof(payload_be));
    header->payload = ntohl(payload_be);
    if (header->payload == 0 || header->payload > kMaxPayloadSize) {
        LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
                   << ")";
        return false;
    }
    return true;
}
// Fills every field of |header|: the current header version, |type|, and
// |payload_size|. The size is left in host byte order; WriteHeader converts
// it for the wire.
void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
                                        adb::proto::PairingPacket::Type type,
                                        uint32_t payload_size) {
    const auto type_byte = static_cast<uint8_t>(static_cast<int>(type));
    *header = PairingPacketHeader{kCurrentKeyHeaderVersion, type_byte, payload_size};
}
// Returns true if the received header type byte |actual| equals
// |expected_type|; otherwise logs both values and returns false.
bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
                                           uint8_t actual) {
    // Convert by value. Reinterpreting the enum's storage as a byte would read
    // its most significant byte on big-endian targets.
    const uint8_t expected = static_cast<uint8_t>(static_cast<int>(expected_type));
    if (actual != expected) {
        LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
                   << " actual=" << static_cast<uint32_t>(actual) << ")";
        return false;
    }
    return true;
}
// Reports the outcome to the user callback, then marks the connection Stopped.
// The callback receives |p| (null on failure, otherwise the peer's info), the
// connection fd, and the caller's opaque pointer.
void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
    const int connection_fd = fd_.get();
    cb_(p, connection_fd, opaque_);
    state_ = State::Stopped;
}
bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
if (fd < 0) {
return false;
}
fd_.reset(fd);
State expected = State::Ready;
if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
return false;
}
cb_ = cb;
opaque_ = opaque;
thread_ = std::thread([this] { StartWorker(); });
return true;
}
// Sends our SPAKE2 message, reads the peer's SPAKE2 message, and initializes
// the pairing cipher from it. Returns false, after logging, on any failure.
bool PairingConnectionCtx::DoExchangeMsgs() {
    const uint32_t our_msg_size = pairing_auth_msg_size(auth_.get());
    std::vector<uint8_t> our_msg(our_msg_size);
    pairing_auth_get_spake2_msg(auth_.get(), our_msg.data());

    // Write our SPAKE2 msg
    PairingPacketHeader header;
    CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, our_msg_size);
    const std::string_view our_payload(reinterpret_cast<const char*>(our_msg.data()),
                                       our_msg.size());
    if (!WriteHeader(&header, our_payload)) {
        LOG(ERROR) << "Failed to write SPAKE2 msg.";
        return false;
    }

    // Read the peer's SPAKE2 msg header (|header| is reused for it).
    if (!ReadHeader(&header)) {
        LOG(ERROR) << "Invalid PairingPacketHeader.";
        return false;
    }
    if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
        return false;
    }

    // Read the SPAKE2 msg payload and initialize the cipher used to encrypt
    // the PeerInfo and certificate.
    auto their_msg = tls_->ReadFully(header.payload);
    const bool cipher_ready =
            !their_msg.empty() &&
            pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size());
    if (!cipher_ready) {
        LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
                   << "]";
        return false;
    }
    return true;
}
bool PairingConnectionCtx::DoExchangePeerInfo() {
// Encrypt PeerInfo
std::vector<uint8_t> buf;
uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
buf.assign(p, p + sizeof(peer_info_));
std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
CHECK(!outbuf.empty());
size_t outsize;
if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
LOG(ERROR) << "Failed to encrypt peer info";
return false;
}
outbuf.resize(outsize);
// Write out the packet header
PairingPacketHeader out_header;
out_header.version = kCurrentKeyHeaderVersion;
out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
out_header.payload = htonl(outbuf.size());
if (!tls_->WriteFully(
std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
LOG(ERROR) << "Unable to write PairingPacketHeader";
return false;
}
// Write out the encrypted payload
if (!tls_->WriteFully(
std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
LOG(ERROR) << "Unable to write encrypted peer info";
return false;
}
// Read in the peer's packet header
PairingPacketHeader header;
if (!ReadHeader(&header)) {
LOG(ERROR) << "Invalid PairingPacketHeader.";
return false;
}
if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
return false;
}
// Read in the encrypted peer certificate
buf = tls_->ReadFully(header.payload);
if (buf.empty()) {
return false;
}
// Try to decrypt the certificate
outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
if (outbuf.empty()) {
LOG(ERROR) << "Unsupported payload while decrypting peer info.";
return false;
}
if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
LOG(ERROR) << "Failed to decrypt";
return false;
}
outbuf.resize(outsize);
// The decrypted message should contain the PeerInfo.
if (outbuf.size() != sizeof(PeerInfo)) {
LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
return false;
}
p = outbuf.data();
::memcpy(&their_info_, p, sizeof(PeerInfo));
p += sizeof(PeerInfo);
return true;
}
void PairingConnectionCtx::StartWorker() {
// Setup the secure transport
if (!SetupTlsConnection()) {
NotifyResult(nullptr);
return;
}
for (;;) {
switch (state_) {
case State::ExchangingMsgs:
if (!DoExchangeMsgs()) {
NotifyResult(nullptr);
return;
}
state_ = State::ExchangingPeerInfo;
break;
case State::ExchangingPeerInfo:
if (!DoExchangePeerInfo()) {
NotifyResult(nullptr);
return;
}
NotifyResult(&their_info_);
return;
case State::Ready:
case State::Stopped:
LOG(FATAL) << __func__ << ": Got invalid state";
return;
}
}
}
// static
// Builds the pairing auth (SPAKE2) context for |role| from |pswd|. Aborts via
// LOG(FATAL) on an unrecognized role instead of falling off the end of a
// non-void function.
PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
    switch (role) {
        case Role::Client:
            return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
        case Role::Server:
            return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
    }
    // Unreachable for valid Role values; guards against a corrupted enum.
    LOG(FATAL) << "Unknown pairing role " << static_cast<int>(role);
    return nullptr;
}
// Shared implementation of pairing_connection_{client,server}_new().
// It first validates the C-API arguments, aborting via CHECK on null pointers
// or zero lengths. It then copies the buffers into owned vectors and
// heap-allocates the context. The caller owns the returned pointer
// (pairing_connection_destroy() deletes it).
static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
                                              size_t pswd_len, const PeerInfo* peer_info,
                                              const uint8_t* x509_cert_pem, size_t x509_size,
                                              const uint8_t* priv_key_pem, size_t priv_size) {
    CHECK(pswd);
    CHECK_GT(pswd_len, 0U);
    CHECK(x509_cert_pem);
    CHECK_GT(x509_size, 0U);
    CHECK(priv_key_pem);
    CHECK_GT(priv_size, 0U);
    CHECK(peer_info);

    const PairingConnectionCtx::Data password(pswd, pswd + pswd_len);
    const PairingConnectionCtx::Data certificate(x509_cert_pem, x509_cert_pem + x509_size);
    const PairingConnectionCtx::Data private_key(priv_key_pem, priv_key_pem + priv_size);
    return new PairingConnectionCtx(role, password, *peer_info, certificate, private_key);
}
// C API: creates a context that pairs in the client role. All pointer
// arguments must be non-null and all sizes non-zero (enforced by CHECK).
PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    constexpr auto kRole = PairingConnectionCtx::Role::Client;
    return CreateConnection(kRole, pswd, pswd_len, peer_info, x509_cert_pem, x509_size,
                            priv_key_pem, priv_size);
}
// C API: creates a context that pairs in the server role. All pointer
// arguments must be non-null and all sizes non-zero (enforced by CHECK).
PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    constexpr auto kRole = PairingConnectionCtx::Role::Server;
    return CreateConnection(kRole, pswd, pswd_len, peer_info, x509_cert_pem, x509_size,
                            priv_key_pem, priv_size);
}
// C API: destroys a context created by pairing_connection_{client,server}_new().
// The destructor waits for the worker thread if one was started.
// |ctx| must be non-null (enforced by CHECK).
void pairing_connection_destroy(PairingConnectionCtx* ctx) {
    CHECK(ctx);
    delete ctx;
}
// C API: starts pairing on |fd| (ownership is transferred to |ctx|). |cb| is
// invoked with |opaque| when pairing finishes. |ctx| must be non-null; this is
// enforced by CHECK, consistent with pairing_connection_destroy(). Returns the
// result of PairingConnectionCtx::Start().
bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
                              void* opaque) {
    CHECK(ctx);
    return ctx->Start(fd, cb, opaque);
}