You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
492 lines
18 KiB
492 lines
18 KiB
/*
|
|
* Copyright (C) 2020 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "adb/pairing/pairing_connection.h"
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <string_view>
|
|
#include <thread>
|
|
#include <vector>
|
|
|
|
#include <adb/pairing/pairing_auth.h>
|
|
#include <adb/tls/tls_connection.h>
|
|
#include <android-base/endian.h>
|
|
#include <android-base/logging.h>
|
|
#include <android-base/macros.h>
|
|
#include <android-base/unique_fd.h>
|
|
|
|
#include "pairing.pb.h"
|
|
|
|
using namespace adb;
|
|
using android::base::unique_fd;
|
|
using TlsError = tls::TlsConnection::TlsError;
|
|
|
|
const uint8_t kCurrentKeyHeaderVersion = 1;
|
|
const uint8_t kMinSupportedKeyHeaderVersion = 1;
|
|
const uint8_t kMaxSupportedKeyHeaderVersion = 1;
|
|
const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
|
|
|
|
// Fixed-size header that precedes every pairing payload. The struct is packed,
// so its in-memory image is exactly 1 + 1 + 4 bytes with no padding.
struct PairingPacketHeader {
    // PairingPacket version.
    uint8_t version;
    // The type of packet (PairingPacket.Type).
    uint8_t type;
    // Size of the payload in bytes.
    uint32_t payload;
} __attribute__((packed));
|
|
|
|
struct PairingAuthDeleter {
|
|
void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); }
|
|
}; // PairingAuthDeleter
|
|
using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
|
|
|
|
// PairingConnectionCtx encapsulates the protocol to authenticate two peers with
|
|
// each other. This class will open the tcp sockets and handle the pairing
|
|
// process. On completion, both sides will have each other's public key
|
|
// (certificate) if successful, otherwise, the pairing failed. The tcp port
|
|
// number is hardcoded (see pairing_connection.cpp).
|
|
//
|
|
// Each PairingConnectionCtx instance represents a different device trying to
|
|
// pair. So for the device, we can have multiple PairingConnectionCtxs while the
|
|
// host may have only one (unless host has a PairingServer).
|
|
//
|
|
// See pairing_connection_test.cpp for example usage.
|
|
//
|
|
struct PairingConnectionCtx {
|
|
public:
|
|
using Data = std::vector<uint8_t>;
|
|
using ResultCallback = pairing_result_cb;
|
|
enum class Role {
|
|
Client,
|
|
Server,
|
|
};
|
|
|
|
explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
|
|
const Data& certificate, const Data& priv_key);
|
|
virtual ~PairingConnectionCtx();
|
|
|
|
// Starts the pairing connection on a separate thread.
|
|
// Upon completion, if the pairing was successful,
|
|
// |cb| will be called with the peer information and certificate.
|
|
// Otherwise, |cb| will be called with empty data. |fd| should already
|
|
// be opened. PairingConnectionCtx will take ownership of the |fd|.
|
|
//
|
|
// Pairing is successful if both server/client uses the same non-empty
|
|
// |pswd|, and they are able to exchange the information. |pswd| and
|
|
// |certificate| must be non-empty. Start() can only be called once in the
|
|
// lifetime of this object.
|
|
//
|
|
// Returns true if the thread was successfully started, false otherwise.
|
|
bool Start(int fd, ResultCallback cb, void* opaque);
|
|
|
|
private:
|
|
// Setup the tls connection.
|
|
bool SetupTlsConnection();
|
|
|
|
/************ PairingPacketHeader methods ****************/
|
|
// Tries to write out the header and payload.
|
|
bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
|
|
// Tries to parse incoming data into the |header|. Returns true if header
|
|
// is valid and header version is supported. |header| is filled on success.
|
|
// |header| may contain garbage if unsuccessful.
|
|
bool ReadHeader(PairingPacketHeader* header);
|
|
// Creates a PairingPacketHeader.
|
|
void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
|
|
uint32_t payload_size);
|
|
// Checks if actual matches expected.
|
|
bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);
|
|
|
|
/*********** State related methods **************/
|
|
// Handles the State::ExchangingMsgs state.
|
|
bool DoExchangeMsgs();
|
|
// Handles the State::ExchangingPeerInfo state.
|
|
bool DoExchangePeerInfo();
|
|
|
|
// The background task to do the pairing.
|
|
void StartWorker();
|
|
|
|
// Calls |cb_| and sets the state to Stopped.
|
|
void NotifyResult(const PeerInfo* p);
|
|
|
|
static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);
|
|
|
|
enum class State {
|
|
Ready,
|
|
ExchangingMsgs,
|
|
ExchangingPeerInfo,
|
|
Stopped,
|
|
};
|
|
|
|
std::atomic<State> state_{State::Ready};
|
|
Role role_;
|
|
Data pswd_;
|
|
PeerInfo peer_info_;
|
|
Data cert_;
|
|
Data priv_key_;
|
|
|
|
// Peer's info
|
|
PeerInfo their_info_;
|
|
|
|
ResultCallback cb_;
|
|
void* opaque_ = nullptr;
|
|
std::unique_ptr<tls::TlsConnection> tls_;
|
|
PairingAuthPtr auth_;
|
|
unique_fd fd_;
|
|
std::thread thread_;
|
|
static constexpr size_t kExportedKeySize = 64;
|
|
}; // PairingConnectionCtx
|
|
|
|
// Stores copies of the role, password, local peer info and TLS credentials.
// The password, certificate and private key must all be non-empty.
PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
                                           const Data& cert, const Data& priv_key)
    : role_(role),
      pswd_(pswd),
      peer_info_(peer_info),
      cert_(cert),
      priv_key_(priv_key) {
    CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
}
|
|
|
|
// Tears the connection down: the fd is closed first, then the worker thread
// (if any) is joined.
PairingConnectionCtx::~PairingConnectionCtx() {
    // Force-close the fd before waiting on the worker.
    // NOTE(review): the worker may still be reading |fd_| at this point, and
    // this close is not synchronized with it — confirm this is intentional.
    fd_.reset();

    if (!thread_.joinable()) {
        return;
    }
    thread_.join();
}
|
|
|
|
bool PairingConnectionCtx::SetupTlsConnection() {
|
|
tls_ = tls::TlsConnection::Create(
|
|
role_ == Role::Server ? tls::TlsConnection::Role::Server
|
|
: tls::TlsConnection::Role::Client,
|
|
std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
|
|
std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
|
|
fd_);
|
|
|
|
if (tls_ == nullptr) {
|
|
LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
|
|
return false;
|
|
}
|
|
|
|
// Allow any peer certificate
|
|
tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
|
|
|
|
// SSL doesn't seem to behave correctly with fdevents so just do a blocking
|
|
// read for the pairing data.
|
|
if (tls_->DoHandshake() != TlsError::Success) {
|
|
LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
|
|
return false;
|
|
}
|
|
|
|
// To ensure the connection is not stolen while we do the PAKE, append the
|
|
// exported key material from the tls connection to the password.
|
|
std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
|
|
if (exportedKeyMaterial.empty()) {
|
|
LOG(ERROR) << "Failed to export key material";
|
|
return false;
|
|
}
|
|
pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()),
|
|
std::make_move_iterator(exportedKeyMaterial.end()));
|
|
auth_ = CreatePairingAuthPtr(role_, pswd_);
|
|
|
|
return true;
|
|
}
|
|
|
|
bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
|
|
std::string_view payload) {
|
|
PairingPacketHeader network_header = *header;
|
|
network_header.payload = htonl(network_header.payload);
|
|
if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
|
|
sizeof(PairingPacketHeader))) ||
|
|
!tls_->WriteFully(payload)) {
|
|
LOG(ERROR) << "Failed to write out PairingPacketHeader";
|
|
state_ = State::Stopped;
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Reads one PairingPacketHeader from the peer and validates it.
//
// |header| receives the parsed fields with the payload size in host byte
// order. Returns true only if a full header was read, its version lies in
// [kMinSupportedKeyHeaderVersion, kMaxSupportedKeyHeaderVersion], its type is
// a valid PairingPacket type, and its payload size is non-zero and no larger
// than kMaxPayloadSize. |header| may be partially filled on failure.
bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
    auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
    // Covers both a failed read (empty result) and any short result, so the
    // fixed-offset parsing below never runs past the buffer.
    if (data.size() != sizeof(PairingPacketHeader)) {
        return false;
    }

    const uint8_t* p = data.data();
    // First byte is always PairingPacketHeader version
    header->version = *p;
    ++p;
    if (header->version < kMinSupportedKeyHeaderVersion ||
        header->version > kMaxSupportedKeyHeaderVersion) {
        // Widen before streaming: a uint8_t would otherwise be printed as a
        // raw character rather than as a number.
        LOG(ERROR) << "PairingPacketHeader version mismatch (us="
                   << static_cast<uint32_t>(kCurrentKeyHeaderVersion)
                   << " them=" << static_cast<uint32_t>(header->version) << ")";
        return false;
    }

    // Next byte is the PairingPacket::Type
    if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
        LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
        return false;
    }
    header->type = *p;
    ++p;

    // Last, the payload size: four bytes in network (big-endian) order.
    // Assemble it bytewise; |p| is not 4-byte aligned here, so it must not be
    // dereferenced as a uint32_t.
    header->payload = (static_cast<uint32_t>(p[0]) << 24) | (static_cast<uint32_t>(p[1]) << 16) |
                      (static_cast<uint32_t>(p[2]) << 8) | static_cast<uint32_t>(p[3]);
    if (header->payload == 0 || header->payload > kMaxPayloadSize) {
        LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
                   << ")";
        return false;
    }

    return true;
}
|
|
|
|
// Populates |header| for an outgoing packet: the current header version, the
// packet |type| narrowed to one byte, and |payload_size| in host byte order.
void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
                                        adb::proto::PairingPacket::Type type,
                                        uint32_t payload_size) {
    header->version = kCurrentKeyHeaderVersion;
    header->type = static_cast<uint8_t>(type);
    header->payload = payload_size;
}
|
|
|
|
// Verifies that a received header carries the packet type we are waiting for.
//
// |expected_type| is the PairingPacket type the protocol requires at this
// point; |actual| is the one-byte type taken from the received header.
// Returns true on a match; otherwise logs both values and returns false.
bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
                                           uint8_t actual) {
    // Convert by value. Reading the first byte of the enum's storage only
    // gives the right answer on little-endian hosts.
    const uint8_t expected = static_cast<uint8_t>(expected_type);
    if (actual == expected) {
        return true;
    }

    LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
               << " actual=" << static_cast<uint32_t>(actual) << ")";
    return false;
}
|
|
|
|
// Reports the pairing outcome to the client: invokes |cb_| with |p|, the raw
// fd and the opaque pointer, then moves the state to Stopped.
void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
    const int raw_fd = fd_.get();
    cb_(p, raw_fd, opaque_);
    state_.store(State::Stopped);
}
|
|
|
|
bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
|
|
if (fd < 0) {
|
|
return false;
|
|
}
|
|
fd_.reset(fd);
|
|
|
|
State expected = State::Ready;
|
|
if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
|
|
return false;
|
|
}
|
|
|
|
cb_ = cb;
|
|
opaque_ = opaque;
|
|
|
|
thread_ = std::thread([this] { StartWorker(); });
|
|
return true;
|
|
}
|
|
|
|
bool PairingConnectionCtx::DoExchangeMsgs() {
|
|
uint32_t payload = pairing_auth_msg_size(auth_.get());
|
|
std::vector<uint8_t> msg(payload);
|
|
pairing_auth_get_spake2_msg(auth_.get(), msg.data());
|
|
|
|
PairingPacketHeader header;
|
|
CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);
|
|
|
|
// Write our SPAKE2 msg
|
|
if (!WriteHeader(&header,
|
|
std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
|
|
LOG(ERROR) << "Failed to write SPAKE2 msg.";
|
|
return false;
|
|
}
|
|
|
|
// Read the peer's SPAKE2 msg header
|
|
if (!ReadHeader(&header)) {
|
|
LOG(ERROR) << "Invalid PairingPacketHeader.";
|
|
return false;
|
|
}
|
|
if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
|
|
return false;
|
|
}
|
|
|
|
// Read the SPAKE2 msg payload and initialize the cipher for
|
|
// encrypting the PeerInfo and certificate.
|
|
auto their_msg = tls_->ReadFully(header.payload);
|
|
if (their_msg.empty() ||
|
|
!pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
|
|
LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
|
|
<< "]";
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool PairingConnectionCtx::DoExchangePeerInfo() {
|
|
// Encrypt PeerInfo
|
|
std::vector<uint8_t> buf;
|
|
uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
|
|
buf.assign(p, p + sizeof(peer_info_));
|
|
std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
|
|
CHECK(!outbuf.empty());
|
|
size_t outsize;
|
|
if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
|
|
LOG(ERROR) << "Failed to encrypt peer info";
|
|
return false;
|
|
}
|
|
outbuf.resize(outsize);
|
|
|
|
// Write out the packet header
|
|
PairingPacketHeader out_header;
|
|
out_header.version = kCurrentKeyHeaderVersion;
|
|
out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
|
|
out_header.payload = htonl(outbuf.size());
|
|
if (!tls_->WriteFully(
|
|
std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
|
|
LOG(ERROR) << "Unable to write PairingPacketHeader";
|
|
return false;
|
|
}
|
|
|
|
// Write out the encrypted payload
|
|
if (!tls_->WriteFully(
|
|
std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
|
|
LOG(ERROR) << "Unable to write encrypted peer info";
|
|
return false;
|
|
}
|
|
|
|
// Read in the peer's packet header
|
|
PairingPacketHeader header;
|
|
if (!ReadHeader(&header)) {
|
|
LOG(ERROR) << "Invalid PairingPacketHeader.";
|
|
return false;
|
|
}
|
|
|
|
if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
|
|
return false;
|
|
}
|
|
|
|
// Read in the encrypted peer certificate
|
|
buf = tls_->ReadFully(header.payload);
|
|
if (buf.empty()) {
|
|
return false;
|
|
}
|
|
|
|
// Try to decrypt the certificate
|
|
outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
|
|
if (outbuf.empty()) {
|
|
LOG(ERROR) << "Unsupported payload while decrypting peer info.";
|
|
return false;
|
|
}
|
|
|
|
if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
|
|
LOG(ERROR) << "Failed to decrypt";
|
|
return false;
|
|
}
|
|
outbuf.resize(outsize);
|
|
|
|
// The decrypted message should contain the PeerInfo.
|
|
if (outbuf.size() != sizeof(PeerInfo)) {
|
|
LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
|
|
return false;
|
|
}
|
|
|
|
p = outbuf.data();
|
|
::memcpy(&their_info_, p, sizeof(PeerInfo));
|
|
p += sizeof(PeerInfo);
|
|
|
|
return true;
|
|
}
|
|
|
|
void PairingConnectionCtx::StartWorker() {
|
|
// Setup the secure transport
|
|
if (!SetupTlsConnection()) {
|
|
NotifyResult(nullptr);
|
|
return;
|
|
}
|
|
|
|
for (;;) {
|
|
switch (state_) {
|
|
case State::ExchangingMsgs:
|
|
if (!DoExchangeMsgs()) {
|
|
NotifyResult(nullptr);
|
|
return;
|
|
}
|
|
state_ = State::ExchangingPeerInfo;
|
|
break;
|
|
case State::ExchangingPeerInfo:
|
|
if (!DoExchangePeerInfo()) {
|
|
NotifyResult(nullptr);
|
|
return;
|
|
}
|
|
NotifyResult(&their_info_);
|
|
return;
|
|
case State::Ready:
|
|
case State::Stopped:
|
|
LOG(FATAL) << __func__ << ": Got invalid state";
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
// static
|
|
PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
|
|
switch (role) {
|
|
case Role::Client:
|
|
return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
|
|
break;
|
|
case Role::Server:
|
|
return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Shared implementation of the client/server factory functions.
//
// CHECK-fails if any pointer is null or any length is zero. The password,
// certificate and private key are copied into byte vectors and |*peer_info| is
// passed to the constructor by reference. Returns a heap-allocated context
// created with new; pairing_connection_destroy() deletes it.
static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
                                              size_t pswd_len, const PeerInfo* peer_info,
                                              const uint8_t* x509_cert_pem, size_t x509_size,
                                              const uint8_t* priv_key_pem, size_t priv_size) {
    CHECK(pswd);
    CHECK_GT(pswd_len, 0U);
    CHECK(x509_cert_pem);
    CHECK_GT(x509_size, 0U);
    CHECK(priv_key_pem);
    CHECK_GT(priv_size, 0U);
    CHECK(peer_info);

    using Bytes = PairingConnectionCtx::Data;
    return new PairingConnectionCtx(role, Bytes(pswd, pswd + pswd_len), *peer_info,
                                    Bytes(x509_cert_pem, x509_cert_pem + x509_size),
                                    Bytes(priv_key_pem, priv_key_pem + priv_size));
}
|
|
|
|
// Creates a pairing context that plays the client role. All inputs are
// validated and copied by CreateConnection().
PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    constexpr auto kRole = PairingConnectionCtx::Role::Client;
    return CreateConnection(kRole, pswd, pswd_len, peer_info, x509_cert_pem, x509_size,
                            priv_key_pem, priv_size);
}
|
|
|
|
// Creates a pairing context that plays the server role. All inputs are
// validated and copied by CreateConnection().
PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    constexpr auto kRole = PairingConnectionCtx::Role::Server;
    return CreateConnection(kRole, pswd, pswd_len, peer_info, x509_cert_pem, x509_size,
                            priv_key_pem, priv_size);
}
|
|
|
|
// Destroys a context created by one of the *_new() functions. |ctx| must not
// be null. The context's destructor closes the fd and joins the worker thread.
void pairing_connection_destroy(PairingConnectionCtx* ctx) {
    CHECK(ctx);
    // Adopt the pointer; the context is deleted when |owner| leaves scope.
    std::unique_ptr<PairingConnectionCtx> owner(ctx);
}
|
|
|
|
// C entry point that starts pairing on |ctx| using |fd|; forwards to
// PairingConnectionCtx::Start() and returns its result. |ctx| must not be
// null. |cb| is invoked with the result and |opaque| when pairing finishes.
bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
                              void* opaque) {
    // Fail loudly on a null context, matching pairing_connection_destroy(),
    // instead of dereferencing it.
    CHECK(ctx);
    return ctx->Start(fd, cb, opaque);
}
|