You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

254 lines
8.4 KiB

// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "osp/impl/quic/quic_client.h"

#include <algorithm>
#include <chrono>
#include <functional>
#include <memory>
#include <utility>

#include "platform/api/task_runner.h"
#include "platform/api/time.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace osp {
QuicClient::QuicClient(
MessageDemuxer* demuxer,
std::unique_ptr<QuicConnectionFactory> connection_factory,
ProtocolConnectionServiceObserver* observer,
ClockNowFunctionPtr now_function,
TaskRunner* task_runner)
: ProtocolConnectionClient(demuxer, observer),
connection_factory_(std::move(connection_factory)),
cleanup_alarm_(now_function, task_runner) {}
// Tears down every pending and established connection before the client goes
// away (see CloseAllConnections()).
QuicClient::~QuicClient() { CloseAllConnections(); }
// Puts the client into the running state, starts the periodic clean-up cycle
// and notifies the observer.
// Returns false without doing anything if the client is already running.
bool QuicClient::Start() {
  const bool already_running = (state_ == State::kRunning);
  if (already_running) {
    return false;
  }
  state_ = State::kRunning;
  // Cleanup() re-arms |cleanup_alarm_| for as long as the state is not
  // stopped, so this first call starts the periodic cycle.
  Cleanup();
  observer_->OnRunning();
  return true;
}
// Closes all connections, moves the client into the stopped state, runs one
// final clean-up pass and notifies the observer.
// Returns false without doing anything if the client is already stopped.
bool QuicClient::Stop() {
  const bool already_stopped = (state_ == State::kStopped);
  if (already_stopped) {
    return false;
  }
  CloseAllConnections();
  state_ = State::kStopped;
  // The state is already stopped, so this pass does not re-arm the alarm.
  Cleanup();
  observer_->OnStopped();
  return true;
}
void QuicClient::Cleanup() {
for (auto& entry : connections_) {
entry.second.delegate->DestroyClosedStreams();
if (!entry.second.delegate->has_streams())
entry.second.connection->Close();
}
for (uint64_t endpoint_id : delete_connections_) {
auto it = connections_.find(endpoint_id);
if (it != connections_.end()) {
connections_.erase(it);
}
}
delete_connections_.clear();
constexpr Clock::duration kQuicCleanupPeriod = std::chrono::milliseconds(500);
if (state_ != State::kStopped) {
cleanup_alarm_.ScheduleFromNow([this] { Cleanup(); }, kQuicCleanupPeriod);
}
}
// Requests a protocol connection to |endpoint|.
//
// - If the client is not running, nothing happens and an empty
//   ConnectRequest (request ID 0) is returned.
// - If an established connection to |endpoint| is known, a new protocol
//   connection on it is delivered synchronously via
//   |request->OnConnectionOpened(0, ...)| and an empty ConnectRequest is
//   returned.
// - Otherwise |request| is queued on a pending connection to |endpoint|
//   (one is started if needed), and the returned ConnectRequest carries the
//   request ID.
QuicClient::ConnectRequest QuicClient::Connect(
    const IPEndpoint& endpoint,
    ConnectionRequestCallback* request) {
  if (state_ != State::kRunning)
    return ConnectRequest(this, 0);

  auto endpoint_entry = endpoint_map_.find(endpoint);
  if (endpoint_entry != endpoint_map_.end()) {
    auto immediate_result = CreateProtocolConnection(endpoint_entry->second);
    if (immediate_result) {
      request->OnConnectionOpened(0, std::move(immediate_result));
      return ConnectRequest(this, 0);
    }
    // The client is running, so a null result means the mapped endpoint ID is
    // no longer in |connections_|: the connection was closed and erased by
    // Cleanup(). Drop the stale mapping and open a fresh connection rather
    // than handing the requester a null connection.
    endpoint_map_.erase(endpoint_entry);
  }
  return CreatePendingConnection(endpoint, request);
}
std::unique_ptr<ProtocolConnection> QuicClient::CreateProtocolConnection(
uint64_t endpoint_id) {
if (state_ != State::kRunning)
return nullptr;
auto connection_entry = connections_.find(endpoint_id);
if (connection_entry == connections_.end())
return nullptr;
return QuicProtocolConnection::FromExisting(
this, connection_entry->second.connection.get(),
connection_entry->second.delegate.get(), endpoint_id);
}
// Notification that |connection| is being destroyed. If it has a stream and
// its endpoint is still established, the endpoint's delegate is told to drop
// it. Otherwise this is a no-op.
void QuicClient::OnConnectionDestroyed(QuicProtocolConnection* connection) {
  if (!connection->stream())
    return;

  auto owner = connections_.find(connection->endpoint_id());
  if (owner != connections_.end())
    owner->second.delegate->DropProtocolConnection(connection);
}
// Called by |delegate| once the crypto handshake for its endpoint completes.
// Moves the matching pending connection into |connections_| under a newly
// allocated endpoint ID and records that ID in |endpoint_map_|. It then gives
// every queued requester its own QuicProtocolConnection through
// OnConnectionOpened().
//
// Returns the new endpoint ID, or 0 if there is no pending connection for the
// delegate's endpoint.
uint64_t QuicClient::OnCryptoHandshakeComplete(
    ServiceConnectionDelegate* delegate,
    uint64_t connection_id) {
  const IPEndpoint& endpoint = delegate->endpoint();
  auto pending_entry = pending_connections_.find(endpoint);
  if (pending_entry == pending_connections_.end())
    return 0;

  // Take everything out of the pending entry and erase it *before* any
  // callback runs. A callback may re-enter this client (e.g. cancel another
  // request for this endpoint), which would otherwise modify the callback list
  // or erase |pending_entry| while it is still in use here.
  ServiceConnectionData connection_data = std::move(pending_entry->second.data);
  auto callbacks = std::move(pending_entry->second.callbacks);
  pending_connections_.erase(pending_entry);

  auto* connection = connection_data.connection.get();
  uint64_t endpoint_id = next_endpoint_id_++;
  endpoint_map_[endpoint] = endpoint_id;
  connections_.emplace(endpoint_id, std::move(connection_data));

  // These requests are resolved now, so they must no longer be cancellable or
  // reported as failed by CloseAllConnections().
  for (const auto& request : callbacks)
    request_map_.erase(request.first);

  for (auto& request : callbacks) {
    request.second->OnConnectionOpened(
        request.first, QuicProtocolConnection::FromExisting(
                           this, connection, delegate, endpoint_id));
  }
  return endpoint_id;
}
void QuicClient::OnIncomingStream(
std::unique_ptr<QuicProtocolConnection> connection) {
// TODO(jophba): Change to just use OnIncomingConnection when the observer
// is properly set up.
connection->CloseWriteEnd();
connection.reset();
}
// Notification that the connection for |endpoint_id| has closed. If that
// endpoint is established, it is queued for removal on the next Cleanup()
// pass and its request-ID counter is reset. |connection_id| is unused.
void QuicClient::OnConnectionClosed(uint64_t endpoint_id,
                                    uint64_t connection_id) {
  // TODO(btolsch): Is this how handshake failure is communicated to the
  // delegate?
  const bool is_established =
      connections_.find(endpoint_id) != connections_.end();
  if (!is_established)
    return;

  // Removal is deferred to Cleanup() rather than done here.
  delete_connections_.push_back(endpoint_id);

  // TODO(crbug.com/openscreen/42): If we reset request IDs when a connection is
  // closed, we might end up re-using request IDs when a new connection is
  // created to the same endpoint.
  endpoint_request_ids_.ResetRequestId(endpoint_id);
}
// Forwards |data_size| bytes received on |connection_id| of |endpoint_id|
// directly to the message demuxer.
void QuicClient::OnDataReceived(uint64_t endpoint_id,
                                uint64_t connection_id,
                                const uint8_t* data,
                                size_t data_size) {
  demuxer_->OnStreamData(endpoint_id, connection_id, data, data_size);
}
// Wraps the connection/delegate pair of a connection attempt that has not yet
// completed its handshake.
QuicClient::PendingConnectionData::PendingConnectionData(
    ServiceConnectionData&& connection_data)
    : data(std::move(connection_data)) {}
// Defaulted, noexcept move constructor (member-wise move).
// NOTE(review): presumably defined out of line so member types only need to
// be complete in this file rather than in the header — TODO confirm.
QuicClient::PendingConnectionData::PendingConnectionData(
    PendingConnectionData&&) noexcept = default;
// Defaulted destructor.
QuicClient::PendingConnectionData::~PendingConnectionData() = default;
// Defaulted, noexcept move assignment (member-wise move).
QuicClient::PendingConnectionData& QuicClient::PendingConnectionData::operator=(
    PendingConnectionData&&) noexcept = default;
// Queues |request| on the pending connection to |endpoint|, starting a new
// connection attempt if none is in progress. The returned ConnectRequest
// carries the request ID; the ID is 0 if the attempt could not be started.
//
// Every issued request ID is recorded in |request_map_|. Without that entry,
// CancelConnectRequest() would never find the request, and
// CloseAllConnections() would never report failure to the requester.
QuicClient::ConnectRequest QuicClient::CreatePendingConnection(
    const IPEndpoint& endpoint,
    ConnectionRequestCallback* request) {
  uint64_t request_id = 0;
  auto pending_entry = pending_connections_.find(endpoint);
  if (pending_entry == pending_connections_.end()) {
    request_id = StartConnectionRequest(endpoint, request);
  } else {
    request_id = next_request_id_++;
    pending_entry->second.callbacks.emplace_back(request_id, request);
  }

  // 0 means "no request" (StartConnectionRequest() failed), so there is
  // nothing to track.
  if (request_id != 0)
    request_map_.emplace(request_id, std::make_pair(endpoint, request));
  return ConnectRequest(this, request_id);
}
uint64_t QuicClient::StartConnectionRequest(
const IPEndpoint& endpoint,
ConnectionRequestCallback* request) {
auto delegate = std::make_unique<ServiceConnectionDelegate>(this, endpoint);
std::unique_ptr<QuicConnection> connection =
connection_factory_->Connect(endpoint, delegate.get());
if (!connection) {
// TODO(btolsch): Need interface/handling for Connect() failures. Or, should
// request->OnConnectionFailed() be called?
OSP_DCHECK(false)
<< __func__
<< ": Factory connect failed, but requestor will never know.";
return 0;
}
auto pending_result = pending_connections_.emplace(
endpoint, PendingConnectionData(ServiceConnectionData(
std::move(connection), std::move(delegate))));
uint64_t request_id = next_request_id_++;
pending_result.first->second.callbacks.emplace_back(request_id, request);
return request_id;
}
void QuicClient::CloseAllConnections() {
for (auto& conn : pending_connections_)
conn.second.data.connection->Close();
pending_connections_.clear();
for (auto& conn : connections_)
conn.second.connection->Close();
connections_.clear();
endpoint_map_.clear();
next_endpoint_id_ = 0;
endpoint_request_ids_.Reset();
for (auto& request : request_map_) {
request.second.second->OnConnectionFailed(request.first);
}
request_map_.clear();
}
void QuicClient::CancelConnectRequest(uint64_t request_id) {
auto request_entry = request_map_.find(request_id);
if (request_entry == request_map_.end())
return;
auto pending_entry = pending_connections_.find(request_entry->second.first);
if (pending_entry != pending_connections_.end()) {
auto& callbacks = pending_entry->second.callbacks;
callbacks.erase(
std::remove_if(
callbacks.begin(), callbacks.end(),
[request_id](const std::pair<uint64_t, ConnectionRequestCallback*>&
callback) {
return request_id == callback.first;
}),
callbacks.end());
if (callbacks.empty())
pending_connections_.erase(pending_entry);
}
request_map_.erase(request_entry);
}
} // namespace osp
} // namespace openscreen