You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
236 lines
7.9 KiB
236 lines
7.9 KiB
4 months ago
|
// Copyright 2019 The Chromium Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style license that can be
|
||
|
// found in the LICENSE file.
|
||
|
|
||
|
#include "cast/common/channel/virtual_connection_router.h"
|
||
|
|
||
|
#include <utility>
|
||
|
|
||
|
#include "cast/common/channel/cast_message_handler.h"
|
||
|
#include "cast/common/channel/connection_namespace_handler.h"
|
||
|
#include "cast/common/channel/message_util.h"
|
||
|
#include "cast/common/channel/proto/cast_channel.pb.h"
|
||
|
#include "util/osp_logging.h"
|
||
|
|
||
|
namespace openscreen {
|
||
|
namespace cast {
|
||
|
|
||
|
using ::cast::channel::CastMessage;
|
||
|
|
||
|
// Compiler-generated default constructor; it is defined here, out of line,
// rather than in the header.
VirtualConnectionRouter::VirtualConnectionRouter() = default;
|
||
|
|
||
|
// Compiler-generated destructor; it is defined here, out of line, rather than
// in the header.
VirtualConnectionRouter::~VirtualConnectionRouter() = default;
|
||
|
|
||
|
// Registers |virtual_connection| with the router. Entries are stored per
// socket ID and, within a socket, keyed by local ID; several entries may share
// one local ID as long as their peer IDs differ.
//
// If an entry with the same socket ID, local ID and peer ID is already present
// it is left untouched and |associated_data| is discarded. Otherwise a new
// entry is created, taking ownership of the IDs and of |associated_data|.
void VirtualConnectionRouter::AddConnection(
    VirtualConnection virtual_connection,
    VirtualConnection::AssociatedData associated_data) {
  auto& socket_map = connections_[virtual_connection.socket_id];
  const auto local_entries =
      socket_map.equal_range(virtual_connection.local_id);

  // Only the entries sharing this local ID can be duplicates. The scan must be
  // bounded by |local_entries.second| (the end of the equal range), NOT by
  // |socket_map.end()|: the range usually ends before the container does, and
  // treating "not found" as "reached the container end" would silently drop
  // new connections.
  for (auto it = local_entries.first; it != local_entries.second; ++it) {
    if (it->second.peer_id == virtual_connection.peer_id) {
      return;  // Already registered; keep the existing entry.
    }
  }

  socket_map.emplace(std::move(virtual_connection.local_id),
                     VCTail{std::move(virtual_connection.peer_id),
                            std::move(associated_data)});
}
|
||
|
|
||
|
bool VirtualConnectionRouter::RemoveConnection(
|
||
|
const VirtualConnection& virtual_connection,
|
||
|
VirtualConnection::CloseReason reason) {
|
||
|
auto socket_entry = connections_.find(virtual_connection.socket_id);
|
||
|
if (socket_entry == connections_.end()) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
auto& socket_map = socket_entry->second;
|
||
|
auto local_entries = socket_map.equal_range(virtual_connection.local_id);
|
||
|
if (local_entries.first == socket_map.end()) {
|
||
|
return false;
|
||
|
}
|
||
|
for (auto it = local_entries.first; it != local_entries.second; ++it) {
|
||
|
if (it->second.peer_id == virtual_connection.peer_id) {
|
||
|
socket_map.erase(it);
|
||
|
if (socket_map.empty()) {
|
||
|
connections_.erase(socket_entry);
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
}
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
void VirtualConnectionRouter::RemoveConnectionsByLocalId(
|
||
|
const std::string& local_id) {
|
||
|
for (auto socket_entry = connections_.begin();
|
||
|
socket_entry != connections_.end();) {
|
||
|
auto& socket_map = socket_entry->second;
|
||
|
auto local_entries = socket_map.equal_range(local_id);
|
||
|
if (local_entries.first != socket_map.end()) {
|
||
|
socket_map.erase(local_entries.first, local_entries.second);
|
||
|
if (socket_map.empty()) {
|
||
|
socket_entry = connections_.erase(socket_entry);
|
||
|
continue;
|
||
|
}
|
||
|
}
|
||
|
++socket_entry;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void VirtualConnectionRouter::RemoveConnectionsBySocketId(int socket_id) {
|
||
|
auto entry = connections_.find(socket_id);
|
||
|
if (entry != connections_.end()) {
|
||
|
connections_.erase(entry);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Looks up the data stored for the entry whose socket ID, local ID and peer ID
// all match |virtual_connection|.
//
// Returns a pointer to the stored data on a match, or absl::nullopt when there
// is no such entry. The pointer refers to storage inside |connections_|.
absl::optional<const VirtualConnection::AssociatedData*>
VirtualConnectionRouter::GetConnectionData(
    const VirtualConnection& virtual_connection) const {
  const auto socket_it = connections_.find(virtual_connection.socket_id);
  if (socket_it != connections_.end()) {
    const auto& by_local_id = socket_it->second;
    const auto range = by_local_id.equal_range(virtual_connection.local_id);
    for (auto entry = range.first; entry != range.second; ++entry) {
      if (entry->second.peer_id == virtual_connection.peer_id) {
        return &entry->second.data;
      }
    }
  }
  return absl::nullopt;
}
|
||
|
|
||
|
bool VirtualConnectionRouter::AddHandlerForLocalId(
|
||
|
std::string local_id,
|
||
|
CastMessageHandler* endpoint) {
|
||
|
return endpoints_.emplace(std::move(local_id), endpoint).second;
|
||
|
}
|
||
|
|
||
|
bool VirtualConnectionRouter::RemoveHandlerForLocalId(
|
||
|
const std::string& local_id) {
|
||
|
return endpoints_.erase(local_id) == 1u;
|
||
|
}
|
||
|
|
||
|
// Takes ownership of |socket|, makes this router its client, and files it in
// |sockets_| under its socket ID together with |error_handler|.
// |socket| is dereferenced unconditionally, so it must not be null.
void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
                                         std::unique_ptr<CastSocket> socket) {
  // Read the ID before the socket is moved into the map entry.
  const int socket_id = socket->socket_id();
  socket->SetClient(this);
  SocketWithHandler owned_entry{std::move(socket), error_handler};
  sockets_.emplace(socket_id, std::move(owned_entry));
}
|
||
|
|
||
|
void VirtualConnectionRouter::CloseSocket(int id) {
|
||
|
auto it = sockets_.find(id);
|
||
|
if (it != sockets_.end()) {
|
||
|
RemoveConnectionsBySocketId(id);
|
||
|
std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
|
||
|
SocketErrorHandler* error_handler = it->second.error_handler;
|
||
|
sockets_.erase(it);
|
||
|
error_handler->OnClose(socket.get());
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Sends |message| over |virtual_conn|.
//
// - If the peer ID is the broadcast ID, the message is handed to
//   BroadcastFromLocalPeer() and its result is returned.
// - Otherwise, for any namespace that is not a transport namespace, the
//   connection must already be known to the router; if it is not,
//   Error::Code::kNoActiveConnection is returned.
// - If no socket is registered for the connection's socket ID,
//   Error::Code::kItemNotFound is returned.
// - Otherwise the source/destination IDs are stamped onto the message and the
//   result of the socket's Send() is returned.
Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn,
                                    CastMessage message) {
  if (virtual_conn.peer_id == kBroadcastId) {
    return BroadcastFromLocalPeer(std::move(virtual_conn.local_id),
                                  std::move(message));
  }

  const bool is_transport = IsTransportNamespace(message.namespace_());
  if (!is_transport && !GetConnectionData(virtual_conn)) {
    return Error::Code::kNoActiveConnection;
  }

  const auto socket_it = sockets_.find(virtual_conn.socket_id);
  if (socket_it == sockets_.end()) {
    return Error::Code::kItemNotFound;
  }

  // The IDs are no longer needed after this point, so move them into the
  // message envelope.
  message.set_source_id(std::move(virtual_conn.local_id));
  message.set_destination_id(std::move(virtual_conn.peer_id));
  return socket_it->second.socket->Send(message);
}
|
||
|
|
||
|
// Broadcasts |message| on behalf of the local endpoint |local_id|.
//
// The message is addressed from |local_id| to the broadcast ID, delivered to
// every registered local handler except the one registered under |local_id|,
// and then sent over every socket held by the router.
//
// Returns the first error reported by any socket's Send(), or a
// default-constructed Error if none reported a failure.
Error VirtualConnectionRouter::BroadcastFromLocalPeer(
    std::string local_id,
    ::cast::channel::CastMessage message) {
  message.set_source_id(std::move(local_id));
  message.set_destination_id(kBroadcastId);

  // Local delivery: every handler but the sender's own. Local handlers get a
  // null socket pointer.
  for (const auto& endpoint : endpoints_) {
    const bool is_sender = !(endpoint.first != message.source_id());
    if (is_sender) {
      continue;
    }
    endpoint.second->OnMessage(this, nullptr, message);
  }

  // Remote delivery: a failure on one socket does not stop the broadcast; only
  // the first failure is remembered and reported.
  Error first_error;
  for (const auto& socket_entry : sockets_) {
    auto send_result = socket_entry.second.socket->Send(message);
    if (!send_result.ok() && first_error.ok()) {
      first_error = std::move(send_result);
    }
  }
  return first_error;
}
|
||
|
|
||
|
// Handles an error reported by |socket|: drops the socket's connection
// entries, removes the socket from |sockets_|, and forwards |error| to the
// handler that was registered with it. Does nothing if the socket's ID is not
// present in |sockets_|.
void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
  const int id = socket->socket_id();
  auto found = sockets_.find(id);
  if (found == sockets_.end()) {
    return;
  }

  RemoveConnectionsBySocketId(id);

  // Hold the owning pointer in a local so the socket object outlives the
  // handler callback below; it is destroyed when this function returns.
  std::unique_ptr<CastSocket> keep_alive = std::move(found->second.socket);
  SocketErrorHandler* const handler = found->second.error_handler;
  sockets_.erase(found);
  handler->OnError(socket, error);
}
|
||
|
|
||
|
// Routes a |message| received on |socket| to the appropriate local handler(s).
//
// - Broadcast destination: every registered handler receives the message.
// - Connection namespace: the message goes to |connection_handler_| if one is
//   set, and is otherwise dropped.
// - Anything else: delivered to the handler registered for the destination ID,
//   provided the namespace is a transport namespace or a matching virtual
//   connection is already known; otherwise the message is dropped.
void VirtualConnectionRouter::OnMessage(CastSocket* socket,
                                        CastMessage message) {
  OSP_DCHECK(socket);

  const std::string& destination = message.destination_id();
  if (destination == kBroadcastId) {
    for (const auto& endpoint : endpoints_) {
      endpoint.second->OnMessage(this, socket, message);
    }
    return;
  }

  // Connection namespace messages are weird: The message.source_id() and
  // message.destination_id() are NOT treated as "envelope routing
  // information," like for all other namespaces. Instead, they are considered
  // part of the payload data for CONNECT/CLOSE requests. Thus, they require
  // special-case handling here.
  if (message.namespace_() == kConnectionNamespace) {
    if (connection_handler_) {
      connection_handler_->OnMessage(this, socket, std::move(message));
    }
    return;
  }

  // Drop all messages for virtual connections that do not yet exist.
  // Exception: All transport namespace messages (e.g., device auth,
  // heartbeats, etc.); because these are always assumed to have a route.
  if (!IsTransportNamespace(message.namespace_())) {
    const VirtualConnection route{destination, message.source_id(),
                                  socket->socket_id()};
    if (!GetConnectionData(route)) {
      return;
    }
  }

  // The handler lookup must happen before |message| is moved, because
  // |destination| refers to a string owned by |message|.
  const auto handler_it = endpoints_.find(destination);
  if (handler_it != endpoints_.end()) {
    handler_it->second->OnMessage(this, socket, std::move(message));
  }
}
|
||
|
|
||
|
} // namespace cast
|
||
|
} // namespace openscreen
|