/* * Copyright (C) 2020 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "adb/pairing/pairing_server.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "internal/constants.h" using android::base::ScopedLockAssertion; using android::base::unique_fd; using namespace adb::crypto; using namespace adb::pairing; // The implementation has two background threads running: one to handle and // accept any new pairing connection requests (socket accept), and the other to // handle connection events (connection started, connection finished). struct PairingServerCtx { public: using Data = std::vector; virtual ~PairingServerCtx(); // All parameters must be non-empty. explicit PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, const Data& priv_key, uint16_t port); // Starts the pairing server. This call is non-blocking. Upon completion, // if the pairing was successful, then |cb| will be called with the PublicKeyHeader // containing the info of the trusted peer. Otherwise, |cb| will be // called with an empty value. Start can only be called once in the lifetime // of this object. // // Returns the port number if PairingServerCtx was successfully started. Otherwise, // returns 0. uint16_t Start(pairing_server_result_cb cb, void* opaque); private: // Setup the server socket to accept incoming connections. Returns the // server port number (> 0 on success). uint16_t SetupServer(); // Force stop the server thread. void StopServer(); // handles a new pairing client connection bool HandleNewClientConnection(int fd) EXCLUDES(conn_mutex_); // ======== connection events thread ============= std::mutex conn_mutex_; std::condition_variable conn_cv_; using FdVal = int; struct ConnectionDeleter { void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); } }; using ConnectionPtr = std::unique_ptr; static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info, const Data& cert, const Data& priv_key); using NewConnectionEvent = std::tuple; // using ConnectionFinishedEvent = std::tuple>; using ConnectionEvent = std::variant; // Queue for connections to write into. We have a separate queue to read // from, in order to minimize the time the server thread is blocked. std::deque conn_write_queue_ GUARDED_BY(conn_mutex_); std::deque conn_read_queue_; // Map of fds to their PairingConnections currently running. std::unordered_map connections_; // Two threads launched when starting the pairing server: // 1) A server thread that waits for incoming client connections, and // 2) A connection events thread that synchonizes events from all of the // clients, since each PairingConnection is running in it's own thread. void StartConnectionEventsThread(); void StartServerThread(); static void PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque); std::thread conn_events_thread_; void ConnectionEventsWorker(); std::thread server_thread_; void ServerWorker(); bool is_terminate_ GUARDED_BY(conn_mutex_) = false; enum class State { Ready, Running, Stopped, }; State state_ = State::Ready; Data pswd_; PeerInfo peer_info_; Data cert_; Data priv_key_; uint16_t port_; pairing_server_result_cb cb_; void* opaque_ = nullptr; bool got_valid_pairing_ = false; static const int kEpollConstSocket = 0; // Used to break the server thread from epoll_wait static const int kEpollConstEventFd = 1; unique_fd epoll_fd_; unique_fd server_fd_; unique_fd event_fd_; }; // PairingServerCtx // static PairingServerCtx::ConnectionPtr PairingServerCtx::CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info, const Data& cert, const Data& priv_key) { return ConnectionPtr(pairing_connection_server_new(pswd.data(), pswd.size(), &peer_info, cert.data(), cert.size(), priv_key.data(), priv_key.size())); } PairingServerCtx::PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, const Data& priv_key, uint16_t port) : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) { CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); } PairingServerCtx::~PairingServerCtx() { // Since these connections have references to us, let's make sure they // destruct before us. if (server_thread_.joinable()) { StopServer(); server_thread_.join(); } { std::lock_guard lock(conn_mutex_); is_terminate_ = true; } conn_cv_.notify_one(); if (conn_events_thread_.joinable()) { conn_events_thread_.join(); } // Notify the cb_ if it hasn't already. if (!got_valid_pairing_ && cb_ != nullptr) { cb_(nullptr, opaque_); } } uint16_t PairingServerCtx::Start(pairing_server_result_cb cb, void* opaque) { cb_ = cb; opaque_ = opaque; if (state_ != State::Ready) { LOG(ERROR) << "PairingServerCtx already running or stopped"; return 0; } port_ = SetupServer(); if (port_ == 0) { LOG(ERROR) << "Unable to start PairingServer"; state_ = State::Stopped; return 0; } LOG(INFO) << "Pairing server started on port " << port_; state_ = State::Running; return port_; } void PairingServerCtx::StopServer() { if (event_fd_.get() == -1) { return; } uint64_t value = 1; ssize_t rc = write(event_fd_.get(), &value, sizeof(value)); if (rc == -1) { // This can happen if the server didn't start. PLOG(ERROR) << "write to eventfd failed"; } else if (rc != sizeof(value)) { LOG(FATAL) << "write to event returned short (" << rc << ")"; } } uint16_t PairingServerCtx::SetupServer() { epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC)); if (epoll_fd_ == -1) { PLOG(ERROR) << "failed to create epoll fd"; return 0; } event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); if (event_fd_ == -1) { PLOG(ERROR) << "failed to create eventfd"; return 0; } server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM)); if (server_fd_.get() == -1) { PLOG(ERROR) << "Failed to start pairing connection server"; return 0; } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) { PLOG(ERROR) << "Failed to make server socket cloexec"; return 0; } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) { PLOG(ERROR) << "Failed to make server socket nonblocking"; return 0; } StartConnectionEventsThread(); StartServerThread(); int port = socket_get_local_port(server_fd_.get()); return (port <= 0 ? 0 : port); } void PairingServerCtx::StartServerThread() { server_thread_ = std::thread([this]() { ServerWorker(); }); } void PairingServerCtx::StartConnectionEventsThread() { conn_events_thread_ = std::thread([this]() { ConnectionEventsWorker(); }); } void PairingServerCtx::ServerWorker() { { struct epoll_event event; event.events = EPOLLIN; event.data.u64 = kEpollConstSocket; CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event)); } { struct epoll_event event; event.events = EPOLLIN; event.data.u64 = kEpollConstEventFd; CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event)); } while (true) { struct epoll_event events[2]; int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1)); if (rc == -1) { PLOG(ERROR) << "epoll_wait failed"; return; } else if (rc == 0) { LOG(ERROR) << "epoll_wait returned 0"; return; } for (int i = 0; i < rc; ++i) { struct epoll_event& event = events[i]; switch (event.data.u64) { case kEpollConstSocket: HandleNewClientConnection(server_fd_.get()); break; case kEpollConstEventFd: uint64_t dummy; int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy))); if (rc != sizeof(dummy)) { PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")"; } return; } } } } // static void PairingServerCtx::PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque) { auto* p = reinterpret_cast(opaque); ConnectionFinishedEvent event; if (peer_info != nullptr) { if (peer_info->type == ADB_RSA_PUB_KEY) { event = std::make_tuple(fd, peer_info->type, std::string(reinterpret_cast(peer_info->data))); } else { LOG(WARNING) << "Ignoring successful pairing because of unknown " << "PeerInfo type=" << peer_info->type; } } else { event = std::make_tuple(fd, 0, std::nullopt); } { std::lock_guard lock(p->conn_mutex_); p->conn_write_queue_.push_back(std::move(event)); } p->conn_cv_.notify_one(); } void PairingServerCtx::ConnectionEventsWorker() { uint8_t num_tries = 0; for (;;) { // Transfer the write queue to the read queue. { std::unique_lock lock(conn_mutex_); ScopedLockAssertion assume_locked(conn_mutex_); if (is_terminate_) { // We check |is_terminate_| twice because condition_variable's // notify() only wakes up a thread if it is in the wait state // prior to notify(). Furthermore, we aren't holding the mutex // when processing the events in |conn_read_queue_|. return; } if (conn_write_queue_.empty()) { // We need to wait for new events, or the termination signal. conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) { return (is_terminate_ || !conn_write_queue_.empty()); }); } if (is_terminate_) { // We're done. return; } // Move all events into the read queue. conn_read_queue_ = std::move(conn_write_queue_); conn_write_queue_.clear(); } // Process all events in the read queue. while (conn_read_queue_.size() > 0) { auto& event = conn_read_queue_.front(); if (auto* p = std::get_if(&event)) { // Ignore if we are already at the max number of connections if (connections_.size() >= internal::kMaxConnections) { conn_read_queue_.pop_front(); continue; } auto [ufd, connection] = std::move(*p); int fd = ufd.release(); bool started = pairing_connection_start(connection.get(), fd, PairingConnectionCallback, this); if (!started) { LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd; ufd.reset(fd); } else { connections_[fd] = std::move(connection); } } else if (auto* p = std::get_if(&event)) { auto [fd, info_type, public_key] = std::move(*p); if (public_key.has_value() && !public_key->empty()) { // Valid pairing. Let's shutdown the server and close any // pairing connections in progress. StopServer(); connections_.clear(); PeerInfo info = {}; info.type = info_type; strncpy(reinterpret_cast(info.data), public_key->data(), public_key->size()); cb_(&info, opaque_); got_valid_pairing_ = true; return; } // Invalid pairing. Close the invalid connection. if (connections_.find(fd) != connections_.end()) { connections_.erase(fd); } if (++num_tries >= internal::kMaxPairingAttempts) { cb_(nullptr, opaque_); // To prevent the destructor from calling it again. cb_ = nullptr; return; } } conn_read_queue_.pop_front(); } } } bool PairingServerCtx::HandleNewClientConnection(int fd) { unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC))); if (ufd == -1) { PLOG(WARNING) << "adb_socket_accept failed fd=" << fd; return false; } auto connection = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_); if (connection == nullptr) { LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd; return false; } // send the new connection to the connection thread for further processing NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection)); { std::lock_guard lock(conn_mutex_); conn_write_queue_.push_back(std::move(event)); } conn_cv_.notify_one(); return true; } uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) { return ctx->Start(cb, opaque); } PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len, const PeerInfo* peer_info, const uint8_t* x509_cert_pem, size_t x509_size, const uint8_t* priv_key_pem, size_t priv_size, uint16_t port) { CHECK(pswd); CHECK_GT(pswd_len, 0U); CHECK(x509_cert_pem); CHECK_GT(x509_size, 0U); CHECK(priv_key_pem); CHECK_GT(priv_size, 0U); CHECK(peer_info); std::vector vec_pswd(pswd, pswd + pswd_len); std::vector vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); std::vector vec_priv_key(priv_key_pem, priv_key_pem + priv_size); return new PairingServerCtx(vec_pswd, *peer_info, vec_x509_cert, vec_priv_key, port); } PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len, const PeerInfo* peer_info, uint16_t port) { auto rsa_2048 = CreateRSA2048Key(); auto x509_cert = GenerateX509Certificate(rsa_2048->GetEvpPkey()); std::string pkey_pem = Key::ToPEMString(rsa_2048->GetEvpPkey()); std::string cert_pem = X509ToPEMString(x509_cert.get()); return pairing_server_new(pswd, pswd_len, peer_info, reinterpret_cast(cert_pem.data()), cert_pem.size(), reinterpret_cast(pkey_pem.data()), pkey_pem.size(), port); } void pairing_server_destroy(PairingServerCtx* ctx) { CHECK(ctx); delete ctx; }