You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

331 lines
12 KiB

// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "cast/common/channel/connection_namespace_handler.h"
#include <algorithm>
#include <string>
#include <type_traits>
#include <utility>
#include "absl/types/optional.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/virtual_connection.h"
#include "cast/common/channel/virtual_connection_router.h"
#include "cast/common/public/cast_socket.h"
#include "util/json/json_serialization.h"
#include "util/json/json_value.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace cast {
using ::cast::channel::CastMessage;
using ::cast::channel::CastMessage_PayloadType;
namespace {
bool IsValidProtocolVersion(int version) {
return ::cast::channel::CastMessage_ProtocolVersion_IsValid(version);
}
absl::optional<int> FindMaxProtocolVersion(const Json::Value* version,
const Json::Value* version_list) {
using ArrayIndex = Json::Value::ArrayIndex;
static_assert(std::is_integral<ArrayIndex>::value,
"Assuming ArrayIndex is integral");
absl::optional<int> max_version;
if (version_list && version_list->isArray()) {
max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
for (auto it = version_list->begin(), end = version_list->end(); it != end;
++it) {
if (it->isInt()) {
int version_int = it->asInt();
if (IsValidProtocolVersion(version_int) && version_int > *max_version) {
max_version = version_int;
}
}
}
}
if (version && version->isInt()) {
int version_int = version->asInt();
if (IsValidProtocolVersion(version_int)) {
if (!max_version) {
max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0;
}
if (version_int > max_version) {
max_version = version_int;
}
}
}
return max_version;
}
VirtualConnection::CloseReason GetCloseReason(
const Json::Value& parsed_message) {
VirtualConnection::CloseReason reason =
VirtualConnection::CloseReason::kClosedByPeer;
absl::optional<int> reason_code = MaybeGetInt(
parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyReasonCode));
if (reason_code) {
int code = reason_code.value();
if (code >= VirtualConnection::CloseReason::kFirstReason &&
code <= VirtualConnection::CloseReason::kLastReason) {
reason = static_cast<VirtualConnection::CloseReason>(code);
}
}
return reason;
}
} // namespace
ConnectionNamespaceHandler::ConnectionNamespaceHandler(
VirtualConnectionRouter* vc_router,
VirtualConnectionPolicy* vc_policy)
: vc_router_(vc_router), vc_policy_(vc_policy) {
OSP_DCHECK(vc_router_);
OSP_DCHECK(vc_policy_);
vc_router_->set_connection_namespace_handler(this);
}
ConnectionNamespaceHandler::~ConnectionNamespaceHandler() {
vc_router_->set_connection_namespace_handler(nullptr);
}
void ConnectionNamespaceHandler::OpenRemoteConnection(
VirtualConnection conn,
RemoteConnectionResultCallback result_callback) {
OSP_DCHECK(!vc_router_->GetConnectionData(conn));
OSP_DCHECK(std::none_of(
pending_remote_requests_.begin(), pending_remote_requests_.end(),
[&](const PendingRequest& request) { return request.conn == conn; }));
pending_remote_requests_.push_back({conn, std::move(result_callback)});
SendConnect(std::move(conn));
}
void ConnectionNamespaceHandler::CloseRemoteConnection(VirtualConnection conn) {
if (RemoveConnection(conn, VirtualConnection::kClosedBySelf)) {
SendClose(std::move(conn));
}
}
void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router,
CastSocket* socket,
CastMessage message) {
if (message.destination_id() == kBroadcastId ||
message.source_id() == kBroadcastId ||
message.payload_type() !=
CastMessage_PayloadType::CastMessage_PayloadType_STRING) {
return;
}
ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
if (result.is_error()) {
return;
}
Json::Value& value = result.value();
if (!value.isObject()) {
return;
}
absl::optional<absl::string_view> type =
MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
if (!type) {
// TODO(btolsch): Some of these paths should have error reporting. One
// possibility is to pass errors back through |router| so higher-level code
// can decide whether to show an error to the user, stop talking to a
// particular device, etc.
return;
}
absl::string_view type_str = type.value();
if (type_str == kMessageTypeConnect) {
HandleConnect(socket, std::move(message), std::move(value));
} else if (type_str == kMessageTypeClose) {
HandleClose(socket, std::move(message), std::move(value));
} else if (type_str == kMessageTypeConnected) {
HandleConnectedResponse(socket, std::move(message), std::move(value));
} else {
// NOTE: Unknown message type so ignore it.
// TODO(btolsch): Should be included in future error reporting.
}
}
void ConnectionNamespaceHandler::HandleConnect(CastSocket* socket,
CastMessage message,
Json::Value parsed_message) {
if (message.destination_id() == kBroadcastId ||
message.source_id() == kBroadcastId) {
return;
}
VirtualConnection virtual_conn{std::move(message.destination_id()),
std::move(message.source_id()),
ToCastSocketId(socket)};
if (!vc_policy_->IsConnectionAllowed(virtual_conn)) {
SendClose(std::move(virtual_conn));
return;
}
absl::optional<int> maybe_conn_type = MaybeGetInt(
parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyConnType));
VirtualConnection::Type conn_type = VirtualConnection::Type::kStrong;
if (maybe_conn_type) {
int int_type = maybe_conn_type.value();
if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) ||
int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) {
SendClose(std::move(virtual_conn));
return;
}
conn_type = static_cast<VirtualConnection::Type>(int_type);
}
VirtualConnection::AssociatedData data;
data.type = conn_type;
absl::optional<absl::string_view> user_agent = MaybeGetString(
parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyUserAgent));
if (user_agent) {
data.user_agent = std::string(user_agent.value());
}
const Json::Value* sender_info_value = parsed_message.find(
JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeySenderInfo));
if (!sender_info_value || !sender_info_value->isObject()) {
// TODO(btolsch): Should this be guessed from user agent?
OSP_DVLOG << "No sender info from protocol.";
}
const Json::Value* version_value = parsed_message.find(
JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
const Json::Value* version_list_value = parsed_message.find(
JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersionList));
absl::optional<int> negotiated_version =
FindMaxProtocolVersion(version_value, version_list_value);
if (negotiated_version) {
data.max_protocol_version = static_cast<VirtualConnection::ProtocolVersion>(
negotiated_version.value());
} else {
data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0;
}
if (socket) {
data.ip_fragment = socket->GetSanitizedIpAddress();
} else {
data.ip_fragment = {};
}
OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", "
<< virtual_conn.peer_id << ", " << virtual_conn.socket_id;
// NOTE: Only send a response for senders that actually sent a version. This
// maintains compatibility with older senders that don't send a version and
// don't expect a response.
if (negotiated_version) {
SendConnectedResponse(virtual_conn, negotiated_version.value());
}
vc_router_->AddConnection(std::move(virtual_conn), std::move(data));
}
void ConnectionNamespaceHandler::HandleClose(CastSocket* socket,
CastMessage message,
Json::Value parsed_message) {
const VirtualConnection conn{std::move(*message.mutable_destination_id()),
std::move(*message.mutable_source_id()),
ToCastSocketId(socket)};
const auto reason = GetCloseReason(parsed_message);
if (RemoveConnection(conn, reason)) {
OSP_DVLOG << "Connection closed (reason: " << reason
<< "): " << conn.local_id << ", " << conn.peer_id << ", "
<< conn.socket_id;
}
}
void ConnectionNamespaceHandler::HandleConnectedResponse(
CastSocket* socket,
CastMessage message,
Json::Value parsed_message) {
const VirtualConnection conn{std::move(message.destination_id()),
std::move(message.source_id()),
ToCastSocketId(socket)};
const auto it = std::find_if(
pending_remote_requests_.begin(), pending_remote_requests_.end(),
[&](const PendingRequest& request) { return request.conn == conn; });
if (it == pending_remote_requests_.end()) {
return;
}
vc_router_->AddConnection(conn,
{VirtualConnection::Type::kStrong,
{},
{},
VirtualConnection::ProtocolVersion::kV2_1_3});
const auto callback = std::move(it->result_callback);
pending_remote_requests_.erase(it);
callback(true);
}
void ConnectionNamespaceHandler::SendConnect(VirtualConnection virtual_conn) {
::cast::channel::CastMessage message =
MakeConnectMessage(virtual_conn.local_id, virtual_conn.peer_id);
vc_router_->Send(std::move(virtual_conn), std::move(message));
}
void ConnectionNamespaceHandler::SendClose(VirtualConnection virtual_conn) {
::cast::channel::CastMessage message =
MakeCloseMessage(virtual_conn.local_id, virtual_conn.peer_id);
vc_router_->Send(std::move(virtual_conn), std::move(message));
}
void ConnectionNamespaceHandler::SendConnectedResponse(
const VirtualConnection& virtual_conn,
int max_protocol_version) {
Json::Value connected_message(Json::ValueType::objectValue);
connected_message[kMessageKeyType] = kMessageTypeConnected;
connected_message[kMessageKeyProtocolVersion] =
static_cast<int>(max_protocol_version);
ErrorOr<std::string> result = json::Stringify(connected_message);
if (result.is_error()) {
return;
}
vc_router_->Send(
virtual_conn,
MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value())));
}
bool ConnectionNamespaceHandler::RemoveConnection(
const VirtualConnection& conn,
VirtualConnection::CloseReason reason) {
bool found_connection = false;
if (vc_router_->GetConnectionData(conn)) {
vc_router_->RemoveConnection(conn, reason);
found_connection = true;
}
// Cancel pending remote request, if any.
const auto it = std::find_if(
pending_remote_requests_.begin(), pending_remote_requests_.end(),
[&](const PendingRequest& request) { return request.conn == conn; });
if (it != pending_remote_requests_.end()) {
const auto callback = std::move(it->result_callback);
pending_remote_requests_.erase(it);
callback(false);
found_connection = true;
}
return found_connection;
}
} // namespace cast
} // namespace openscreen