/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "adb/pairing/pairing_connection.h"

// NOTE(review): the "<...>" part of the system includes below was missing from
// the reviewed copy of this file and has been restored from what the code
// uses. Confirm the exact header names against the build.
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <atomic>
#include <functional>
#include <memory>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

#include <adb/pairing/pairing_auth.h>
#include <adb/tls/tls_connection.h>
#include <android-base/endian.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/unique_fd.h>

#include "pairing.pb.h"

using namespace adb;
using android::base::unique_fd;
using TlsError = tls::TlsConnection::TlsError;

// Header version written by this implementation, and the inclusive range of
// versions accepted from the peer.
const uint8_t kCurrentKeyHeaderVersion = 1;
const uint8_t kMinSupportedKeyHeaderVersion = 1;
const uint8_t kMaxSupportedKeyHeaderVersion = 1;
// Upper bound on the payload size a peer may announce in a packet header.
const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;

// Wire header that precedes every pairing packet. |payload| travels in network
// byte order; WriteHeader()/ReadHeader() do the conversion.
struct PairingPacketHeader {
    uint8_t version;   // PairingPacket version
    uint8_t type;      // the type of packet (PairingPacket.Type)
    uint32_t payload;  // Size of the payload in bytes
} __attribute__((packed));

// Releases a PairingAuthCtx through pairing_auth_destroy().
struct PairingAuthDeleter {
    void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); }
};  // PairingAuthDeleter
using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;

// PairingConnectionCtx encapsulates the protocol to authenticate two peers with
// each other. It runs the pairing process over a caller-supplied, already
// opened file descriptor (see Start()). On completion, both sides will have
// each other's PeerInfo if successful, otherwise, the pairing failed.
//
// Each PairingConnectionCtx instance represents a different device trying to
// pair. So for the device, we can have multiple PairingConnectionCtxs while the
// host may have only one (unless host has a PairingServer).
//
// See pairing_connection_test.cpp for example usage.
//
struct PairingConnectionCtx {
  public:
    using Data = std::vector<uint8_t>;
    using ResultCallback = pairing_result_cb;
    enum class Role {
        Client,
        Server,
    };

    // |pswd|, |certificate| and |priv_key| must be non-empty (CHECK-enforced).
    explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
                                  const Data& certificate, const Data& priv_key);
    virtual ~PairingConnectionCtx();

    // Starts the pairing connection on a separate thread.
    // Upon completion, if the pairing was successful,
    // |cb| will be called with the peer information and certificate.
    // Otherwise, |cb| will be called with empty data. |fd| should already
    // be opened. PairingConnectionCtx will take ownership of the |fd|.
    //
    // Pairing is successful if both server/client uses the same non-empty
    // |pswd|, and they are able to exchange the information. |pswd| and
    // |certificate| must be non-empty. Start() can only be called once in the
    // lifetime of this object.
    //
    // Returns true if the thread was successfully started, false otherwise.
    bool Start(int fd, ResultCallback cb, void* opaque);

  private:
    // Setup the tls connection.
    bool SetupTlsConnection();

    /************ PairingPacketHeader methods ****************/
    // Tries to write out the header and payload.
    bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
    // Tries to parse incoming data into the |header|. Returns true if header
    // is valid and header version is supported. |header| is filled on success.
    // |header| may contain garbage if unsuccessful.
    bool ReadHeader(PairingPacketHeader* header);
    // Creates a PairingPacketHeader.
    void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
                      uint32_t payload_size);
    // Checks if actual matches expected.
    bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);

    /*********** State related methods **************/
    // Handles the State::ExchangingMsgs state.
    bool DoExchangeMsgs();
    // Handles the State::ExchangingPeerInfo state.
    bool DoExchangePeerInfo();

    // The background task to do the pairing.
    void StartWorker();

    // Calls |cb_| and sets the state to Stopped.
    void NotifyResult(const PeerInfo* p);

    // Creates the client or server PairingAuthCtx for |role| from |pswd|.
    static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);

    enum class State {
        Ready,
        ExchangingMsgs,
        ExchangingPeerInfo,
        Stopped,
    };

    std::atomic<State> state_{State::Ready};
    Role role_;
    Data pswd_;
    PeerInfo peer_info_;
    Data cert_;
    Data priv_key_;

    // Peer's info
    PeerInfo their_info_;

    ResultCallback cb_;
    void* opaque_ = nullptr;
    std::unique_ptr<tls::TlsConnection> tls_;
    PairingAuthPtr auth_;
    unique_fd fd_;
    std::thread thread_;
    // Number of bytes of TLS keying material appended to the password.
    static constexpr size_t kExportedKeySize = 64;
};  // PairingConnectionCtx

PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
                                           const Data& cert, const Data& priv_key)
    : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) {
    CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
}

PairingConnectionCtx::~PairingConnectionCtx() {
    // Force close the fd and wait for the worker thread to finish.
    fd_.reset();
    if (thread_.joinable()) {
        thread_.join();
    }
}

// Creates the TLS connection over |fd_|, performs the handshake, then derives
// the PAKE password (|pswd_| + exported TLS keying material) and creates
// |auth_| from it. Returns false on any failure.
bool PairingConnectionCtx::SetupTlsConnection() {
    tls_ = tls::TlsConnection::Create(
            role_ == Role::Server ? tls::TlsConnection::Role::Server
                                  : tls::TlsConnection::Role::Client,
            std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
            std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
            fd_);

    if (tls_ == nullptr) {
        LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
        return false;
    }

    // Allow any peer certificate
    tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });

    // SSL doesn't seem to behave correctly with fdevents so just do a blocking
    // read for the pairing data.
    if (tls_->DoHandshake() != TlsError::Success) {
        LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
        return false;
    }

    // To ensure the connection is not stolen while we do the PAKE, append the
    // exported key material from the tls connection to the password.
    std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
    if (exportedKeyMaterial.empty()) {
        LOG(ERROR) << "Failed to export key material";
        return false;
    }
    pswd_.insert(pswd_.end(), exportedKeyMaterial.begin(), exportedKeyMaterial.end());

    auth_ = CreatePairingAuthPtr(role_, pswd_);
    if (auth_ == nullptr) {
        // Without an auth context every later pairing_auth_* call would be
        // handed a null pointer, so fail the setup here instead.
        LOG(ERROR) << "Unable to create PairingAuthCtx";
        return false;
    }

    return true;
}

// Writes |header| (payload size converted to network byte order) followed by
// |payload|. On failure the state is set to Stopped and false is returned.
bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
                                       std::string_view payload) {
    PairingPacketHeader network_header = *header;
    network_header.payload = htonl(network_header.payload);
    if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
                                           sizeof(PairingPacketHeader))) ||
        !tls_->WriteFully(payload)) {
        LOG(ERROR) << "Failed to write out PairingPacketHeader";
        state_ = State::Stopped;
        return false;
    }
    return true;
}

// Reads one PairingPacketHeader from the peer and validates its version, type
// and payload size. |header->payload| is returned in host byte order.
bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
    auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
    if (data.size() < sizeof(PairingPacketHeader)) {
        return false;
    }

    const uint8_t* p = data.data();
    // First byte is always PairingPacketHeader version
    header->version = *p;
    ++p;
    if (header->version < kMinSupportedKeyHeaderVersion ||
        header->version > kMaxSupportedKeyHeaderVersion) {
        // Cast to an integer type: streaming a uint8_t would emit a raw
        // character instead of the version number.
        LOG(ERROR) << "PairingPacketHeader version mismatch (us="
                   << static_cast<uint32_t>(kCurrentKeyHeaderVersion)
                   << " them=" << static_cast<uint32_t>(header->version) << ")";
        return false;
    }
    // Next byte is the PairingPacket::Type
    if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
        LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
        return false;
    }
    header->type = *p;
    ++p;
    // Last, the payload size. It sits at byte offset 2 of the buffer, so copy
    // it out rather than dereferencing a misaligned uint32_t pointer.
    uint32_t network_payload = 0;
    ::memcpy(&network_payload, p, sizeof(network_payload));
    header->payload = ntohl(network_payload);
    if (header->payload == 0 || header->payload > kMaxPayloadSize) {
        LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
                   << ")";
        return false;
    }

    return true;
}

// Fills |header| with the current version, |type| and |payload_size| (host
// byte order).
void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
                                        adb::proto::PairingPacket::Type type,
                                        uint32_t payload_size) {
    header->version = kCurrentKeyHeaderVersion;
    uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type));
    header->type = type8;
    header->payload = payload_size;
}

// Returns true if the wire type byte |actual| equals |expected_type|; logs and
// returns false otherwise.
bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
                                           uint8_t actual) {
    // Convert by value (same conversion as CreateHeader()). Reading the first
    // byte of the enum object through a pointer depends on host endianness.
    const uint8_t expected = static_cast<uint8_t>(static_cast<int>(expected_type));
    if (actual != expected) {
        LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
                   << " actual=" << static_cast<uint32_t>(actual) << ")";
        return false;
    }
    return true;
}

// Invokes the result callback with |p| (nullptr on failure) and marks the
// connection as Stopped.
void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
    cb_(p, fd_.get(), opaque_);
    state_ = State::Stopped;
}

bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
    if (fd < 0) {
        return false;
    }
    // Take ownership of |fd| right away, but do not touch |fd_| yet: if this
    // object was already started, |fd_| is in use by the worker thread and
    // must not be closed or replaced underneath it.
    unique_fd incoming(fd);

    State expected = State::Ready;
    if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
        // Already started (or stopped). |incoming| closes |fd| on return.
        return false;
    }

    fd_ = std::move(incoming);
    cb_ = cb;
    opaque_ = opaque;

    thread_ = std::thread([this] { StartWorker(); });
    return true;
}

// Sends our SPAKE2 message, reads the peer's, and initializes the cipher used
// to encrypt the PeerInfo exchange.
bool PairingConnectionCtx::DoExchangeMsgs() {
    uint32_t payload = pairing_auth_msg_size(auth_.get());
    std::vector<uint8_t> msg(payload);
    pairing_auth_get_spake2_msg(auth_.get(), msg.data());

    PairingPacketHeader header;
    CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);

    // Write our SPAKE2 msg
    if (!WriteHeader(&header,
                     std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
        LOG(ERROR) << "Failed to write SPAKE2 msg.";
        return false;
    }

    // Read the peer's SPAKE2 msg header
    if (!ReadHeader(&header)) {
        LOG(ERROR) << "Invalid PairingPacketHeader.";
        return false;
    }
    if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
        return false;
    }

    // Read the SPAKE2 msg payload and initialize the cipher for
    // encrypting the PeerInfo and certificate.
    auto their_msg = tls_->ReadFully(header.payload);
    if (their_msg.empty() ||
        !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
        LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
                   << "]";
        return false;
    }

    return true;
}

// Encrypts and sends our PeerInfo, then reads and decrypts the peer's PeerInfo
// into |their_info_|.
bool PairingConnectionCtx::DoExchangePeerInfo() {
    // Encrypt PeerInfo
    std::vector<uint8_t> buf;
    uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
    buf.assign(p, p + sizeof(peer_info_));
    std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
    CHECK(!outbuf.empty());
    size_t outsize = 0;
    if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
        LOG(ERROR) << "Failed to encrypt peer info";
        return false;
    }
    outbuf.resize(outsize);

    // Write out the packet header. The header is sent directly here, so the
    // payload size is converted to network byte order by hand.
    PairingPacketHeader out_header;
    out_header.version = kCurrentKeyHeaderVersion;
    out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
    out_header.payload = htonl(static_cast<uint32_t>(outbuf.size()));
    if (!tls_->WriteFully(
                std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
        LOG(ERROR) << "Unable to write PairingPacketHeader";
        return false;
    }

    // Write out the encrypted payload
    if (!tls_->WriteFully(
                std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
        LOG(ERROR) << "Unable to write encrypted peer info";
        return false;
    }

    // Read in the peer's packet header
    PairingPacketHeader header;
    if (!ReadHeader(&header)) {
        LOG(ERROR) << "Invalid PairingPacketHeader.";
        return false;
    }

    if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
        return false;
    }

    // Read in the encrypted peer info
    buf = tls_->ReadFully(header.payload);
    if (buf.empty()) {
        return false;
    }

    // Try to decrypt the peer info
    outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
    if (outbuf.empty()) {
        LOG(ERROR) << "Unsupported payload while decrypting peer info.";
        return false;
    }

    if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
        LOG(ERROR) << "Failed to decrypt";
        return false;
    }
    outbuf.resize(outsize);

    // The decrypted message should contain the PeerInfo.
    if (outbuf.size() != sizeof(PeerInfo)) {
        LOG(ERROR) << "Got size=" << outbuf.size() << " PeerInfo.size=" << sizeof(PeerInfo);
        return false;
    }

    ::memcpy(&their_info_, outbuf.data(), sizeof(PeerInfo));

    return true;
}

// Worker thread body: sets up TLS, then drives the state machine until the
// pairing succeeds or fails. The result callback is always invoked exactly
// once before this returns (except on the LOG(FATAL) path).
void PairingConnectionCtx::StartWorker() {
    // Setup the secure transport
    if (!SetupTlsConnection()) {
        NotifyResult(nullptr);
        return;
    }

    for (;;) {
        switch (state_) {
            case State::ExchangingMsgs:
                if (!DoExchangeMsgs()) {
                    NotifyResult(nullptr);
                    return;
                }
                state_ = State::ExchangingPeerInfo;
                break;
            case State::ExchangingPeerInfo:
                if (!DoExchangePeerInfo()) {
                    NotifyResult(nullptr);
                    return;
                }
                NotifyResult(&their_info_);
                return;
            case State::Ready:
            case State::Stopped:
                LOG(FATAL) << __func__ << ": Got invalid state";
                return;
        }
    }
}

// static
PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
    switch (role) {
        case Role::Client:
            return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
        case Role::Server:
            return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
    }
    // Only reachable if |role| holds a value outside the enum. Return an empty
    // pointer rather than falling off the end of a non-void function.
    return nullptr;
}

// Validates the raw C arguments (CHECK-enforced), copies them into vectors and
// allocates a PairingConnectionCtx for |role|. The caller owns the result and
// releases it with pairing_connection_destroy().
static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
                                              size_t pswd_len, const PeerInfo* peer_info,
                                              const uint8_t* x509_cert_pem, size_t x509_size,
                                              const uint8_t* priv_key_pem, size_t priv_size) {
    CHECK(pswd);
    CHECK_GT(pswd_len, 0U);
    CHECK(x509_cert_pem);
    CHECK_GT(x509_size, 0U);
    CHECK(priv_key_pem);
    CHECK_GT(priv_size, 0U);
    CHECK(peer_info);
    std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
    std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
    std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
    return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key);
}

// Creates a client-role pairing connection. See CreateConnection().
PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info,
                            x509_cert_pem, x509_size, priv_key_pem, priv_size);
}

// Creates a server-role pairing connection. See CreateConnection().
PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
                                                    const PeerInfo* peer_info,
                                                    const uint8_t* x509_cert_pem, size_t x509_size,
                                                    const uint8_t* priv_key_pem, size_t priv_size) {
    return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info,
                            x509_cert_pem, x509_size, priv_key_pem, priv_size);
}

// Destroys |ctx|. The destructor closes the fd and joins the worker thread.
void pairing_connection_destroy(PairingConnectionCtx* ctx) {
    CHECK(ctx);
    delete ctx;
}

// Starts pairing on |ctx| over |fd|. See PairingConnectionCtx::Start().
bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
                              void* opaque) {
    CHECK(ctx);
    return ctx->Start(fd, cb, opaque);
}