You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
306 lines
11 KiB
306 lines
11 KiB
// Copyright 2020 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include <openssl/evp.h>
|
|
#include <openssl/mem.h>
|
|
|
|
#include <atomic>
|
|
#include <chrono>
|
|
|
|
#include "cast/common/certificate/cast_trust_store.h"
|
|
#include "cast/common/certificate/testing/test_helpers.h"
|
|
#include "cast/common/channel/connection_namespace_handler.h"
|
|
#include "cast/common/channel/message_util.h"
|
|
#include "cast/common/channel/virtual_connection_router.h"
|
|
#include "cast/common/public/cast_socket.h"
|
|
#include "cast/receiver/channel/device_auth_namespace_handler.h"
|
|
#include "cast/receiver/channel/static_credentials.h"
|
|
#include "cast/receiver/public/receiver_socket_factory.h"
|
|
#include "cast/sender/public/sender_socket_factory.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
#include "platform/api/serial_delete_ptr.h"
|
|
#include "platform/api/tls_connection_factory.h"
|
|
#include "platform/base/tls_connect_options.h"
|
|
#include "platform/base/tls_credentials.h"
|
|
#include "platform/base/tls_listen_options.h"
|
|
#include "platform/impl/logging.h"
|
|
#include "platform/impl/network_interface.h"
|
|
#include "platform/impl/platform_client_posix.h"
|
|
#include "testing/util/task_util.h"
|
|
#include "util/crypto/certificate_utils.h"
|
|
#include "util/osp_logging.h"
|
|
|
|
namespace openscreen {
|
|
namespace cast {
|
|
namespace {
|
|
|
|
using ::testing::_;
|
|
using ::testing::StrictMock;
|
|
|
|
constexpr char kLogDecorator[] = "--- ";
|
|
|
|
} // namespace
|
|
|
|
class SenderSocketsClient : public SenderSocketFactory::Client,
|
|
public VirtualConnectionRouter::SocketErrorHandler {
|
|
public:
|
|
explicit SenderSocketsClient(VirtualConnectionRouter* router) // NOLINT
|
|
: router_(router) {}
|
|
virtual ~SenderSocketsClient() = default;
|
|
|
|
CastSocket* socket() const { return socket_; }
|
|
|
|
// SenderSocketFactory::Client overrides.
|
|
void OnConnected(SenderSocketFactory* factory,
|
|
const IPEndpoint& endpoint,
|
|
std::unique_ptr<CastSocket> socket) {
|
|
OSP_CHECK(!socket_);
|
|
OSP_LOG_INFO << kLogDecorator
|
|
<< "Sender connected to endpoint: " << endpoint;
|
|
socket_ = socket.get();
|
|
router_->TakeSocket(this, std::move(socket));
|
|
}
|
|
|
|
void OnError(SenderSocketFactory* factory,
|
|
const IPEndpoint& endpoint,
|
|
Error error) override {
|
|
OSP_LOG_FATAL << error;
|
|
}
|
|
|
|
// VirtualConnectionRouter::SocketErrorHandler overrides.
|
|
void OnClose(CastSocket* socket) override {
|
|
socket_ = nullptr;
|
|
OnCloseMock(socket);
|
|
}
|
|
void OnError(CastSocket* socket, Error error) override {
|
|
socket_ = nullptr;
|
|
OnErrorMock(socket, std::move(error));
|
|
}
|
|
|
|
MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
|
|
MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());
|
|
|
|
private:
|
|
VirtualConnectionRouter* const router_;
|
|
std::atomic<CastSocket*> socket_{nullptr};
|
|
};
|
|
|
|
class ReceiverSocketsClient
|
|
: public ReceiverSocketFactory::Client,
|
|
public VirtualConnectionRouter::SocketErrorHandler {
|
|
public:
|
|
explicit ReceiverSocketsClient(VirtualConnectionRouter* router)
|
|
: router_(router) {}
|
|
virtual ~ReceiverSocketsClient() = default;
|
|
|
|
const IPEndpoint& endpoint() const { return endpoint_; }
|
|
CastSocket* socket() const { return socket_; }
|
|
|
|
// ReceiverSocketFactory::Client overrides.
|
|
void OnConnected(ReceiverSocketFactory* factory,
|
|
const IPEndpoint& endpoint,
|
|
std::unique_ptr<CastSocket> socket) override {
|
|
OSP_CHECK(!socket_);
|
|
OSP_LOG_INFO << kLogDecorator
|
|
<< "Receiver got connection from endpoint: " << endpoint;
|
|
endpoint_ = endpoint;
|
|
socket_ = socket.get();
|
|
router_->TakeSocket(this, std::move(socket));
|
|
}
|
|
|
|
void OnError(ReceiverSocketFactory* factory, Error error) override {
|
|
OSP_LOG_FATAL << error;
|
|
}
|
|
|
|
// VirtualConnectionRouter::SocketErrorHandler overrides.
|
|
void OnClose(CastSocket* socket) override {
|
|
socket_ = nullptr;
|
|
OnCloseMock(socket);
|
|
}
|
|
void OnError(CastSocket* socket, Error error) override {
|
|
socket_ = nullptr;
|
|
OnErrorMock(socket, std::move(error));
|
|
}
|
|
|
|
MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
|
|
MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());
|
|
|
|
private:
|
|
VirtualConnectionRouter* router_;
|
|
IPEndpoint endpoint_;
|
|
std::atomic<CastSocket*> socket_{nullptr};
|
|
};
|
|
|
|
class CastSocketE2ETest : public ::testing::Test {
|
|
public:
|
|
void SetUp() override {
|
|
PlatformClientPosix::Create(std::chrono::milliseconds(10));
|
|
task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
|
|
|
|
sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
|
|
sender_client_ =
|
|
std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get());
|
|
sender_factory_ = MakeSerialDelete<SenderSocketFactory>(
|
|
task_runner_, sender_client_.get(), task_runner_);
|
|
sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
|
|
task_runner_,
|
|
TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_)
|
|
.release());
|
|
sender_factory_->set_factory(sender_tls_factory_.get());
|
|
|
|
ErrorOr<GeneratedCredentials> creds =
|
|
GenerateCredentialsForTesting("Device ID");
|
|
ASSERT_TRUE(creds.is_value());
|
|
credentials_ = std::move(creds.value());
|
|
|
|
CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der);
|
|
auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>(
|
|
task_runner_, credentials_.provider.get());
|
|
receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
|
|
receiver_router_->AddHandlerForLocalId(kPlatformReceiverId,
|
|
auth_handler_.get());
|
|
receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>(
|
|
receiver_router_.get());
|
|
receiver_factory_ = MakeSerialDelete<ReceiverSocketFactory>(
|
|
task_runner_, receiver_client_.get(), receiver_router_.get());
|
|
|
|
receiver_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
|
|
task_runner_, TlsConnectionFactory::CreateFactory(
|
|
receiver_factory_.get(), task_runner_)
|
|
.release());
|
|
}
|
|
|
|
void TearDown() override {
|
|
OSP_LOG_INFO << "Shutting down";
|
|
sender_router_.reset();
|
|
receiver_router_.reset();
|
|
receiver_tls_factory_.reset();
|
|
receiver_factory_.reset();
|
|
auth_handler_.reset();
|
|
sender_tls_factory_.reset();
|
|
sender_factory_.reset();
|
|
CastTrustStore::ResetInstance();
|
|
PlatformClientPosix::ShutDown();
|
|
}
|
|
|
|
protected:
|
|
IPAddress GetLoopbackV4Address() {
|
|
absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
|
|
OSP_CHECK(loopback);
|
|
IPAddress address = loopback->GetIpAddressV4();
|
|
OSP_CHECK(address);
|
|
return address;
|
|
}
|
|
|
|
IPAddress GetLoopbackV6Address() {
|
|
absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
|
|
OSP_CHECK(loopback);
|
|
IPAddress address = loopback->GetIpAddressV6();
|
|
return address;
|
|
}
|
|
|
|
void Connect(const IPAddress& address) {
|
|
uint16_t port = 65321;
|
|
OSP_LOG_INFO << kLogDecorator << "Starting socket factories";
|
|
task_runner_->PostTask([this, &address, port]() {
|
|
OSP_LOG_INFO << kLogDecorator << "Receiver TLS factory Listen()";
|
|
receiver_tls_factory_->SetListenCredentials(credentials_.tls_credentials);
|
|
receiver_tls_factory_->Listen(IPEndpoint{address, port},
|
|
TlsListenOptions{1u});
|
|
});
|
|
|
|
task_runner_->PostTask([this, &address, port]() {
|
|
OSP_LOG_INFO << kLogDecorator << "Sender CastSocket factory Connect()";
|
|
sender_factory_->Connect(IPEndpoint{address, port},
|
|
SenderSocketFactory::DeviceMediaPolicy::kNone,
|
|
sender_router_.get());
|
|
});
|
|
|
|
WaitForCondition([this]() { return sender_client_->socket(); });
|
|
}
|
|
|
|
void ConnectSocketsV4() {
|
|
OSP_LOG_INFO << "Getting loopback IPv4 address";
|
|
IPAddress loopback_address = GetLoopbackV4Address();
|
|
OSP_LOG_INFO << "Connecting CastSockets";
|
|
Connect(loopback_address);
|
|
}
|
|
|
|
template <typename SocketClient, typename PeerSocketClient>
|
|
void CloseSocketsFromOneEnd(VirtualConnectionRouter* router,
|
|
SocketClient* client,
|
|
PeerSocketClient* peer_client) {
|
|
// TODO(issuetracker.google.com/169967989): Would like to have a symmetric
|
|
// OnClose check.
|
|
EXPECT_CALL(*client, OnCloseMock(client->socket()));
|
|
EXPECT_CALL(*peer_client, OnErrorMock(peer_client->socket(), _))
|
|
.WillOnce([](CastSocket* socket, Error error) {
|
|
EXPECT_EQ(error.code(), Error::Code::kSocketClosedFailure);
|
|
});
|
|
int32_t id = client->socket()->socket_id();
|
|
std::atomic_bool did_run{false};
|
|
task_runner_->PostTask([id, router, &did_run]() {
|
|
router->CloseSocket(id);
|
|
did_run = true;
|
|
});
|
|
OSP_LOG_INFO << "Waiting for socket to close";
|
|
WaitForCondition([&did_run]() { return did_run.load(); });
|
|
EXPECT_FALSE(sender_client_->socket());
|
|
EXPECT_FALSE(receiver_client_->socket());
|
|
}
|
|
|
|
TaskRunner* task_runner_;
|
|
|
|
// NOTE: Sender components.
|
|
SerialDeletePtr<VirtualConnectionRouter> sender_router_;
|
|
std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_;
|
|
SerialDeletePtr<SenderSocketFactory> sender_factory_;
|
|
SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_;
|
|
|
|
// NOTE: Receiver components.
|
|
SerialDeletePtr<VirtualConnectionRouter> receiver_router_;
|
|
GeneratedCredentials credentials_;
|
|
SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_;
|
|
std::unique_ptr<StrictMock<ReceiverSocketsClient>> receiver_client_;
|
|
SerialDeletePtr<ReceiverSocketFactory> receiver_factory_;
|
|
SerialDeletePtr<TlsConnectionFactory> receiver_tls_factory_;
|
|
};
|
|
|
|
// These test the most basic setup of a complete CastSocket. This means
|
|
// constructing both a SenderSocketFactory and ReceiverSocketFactory, making a
|
|
// TLS connection to a known port over the loopback device, and checking device
|
|
// authentication.
|
|
TEST_F(CastSocketE2ETest, ConnectV4) {
|
|
ConnectSocketsV4();
|
|
}
|
|
|
|
TEST_F(CastSocketE2ETest, ConnectV6) {
|
|
OSP_LOG_INFO << "Getting loopback IPv6 address";
|
|
IPAddress loopback_address = GetLoopbackV6Address();
|
|
if (loopback_address) {
|
|
OSP_LOG_INFO << "Connecting CastSockets";
|
|
Connect(loopback_address);
|
|
} else {
|
|
OSP_LOG_WARN << "Test skipped due to missing IPv6 loopback address";
|
|
}
|
|
}
|
|
|
|
TEST_F(CastSocketE2ETest, SenderClose) {
|
|
ConnectSocketsV4();
|
|
|
|
CloseSocketsFromOneEnd(sender_router_.get(), sender_client_.get(),
|
|
receiver_client_.get());
|
|
}
|
|
|
|
TEST_F(CastSocketE2ETest, ReceiverClose) {
|
|
ConnectSocketsV4();
|
|
|
|
CloseSocketsFromOneEnd(receiver_router_.get(), receiver_client_.get(),
|
|
sender_client_.get());
|
|
}
|
|
|
|
} // namespace cast
|
|
} // namespace openscreen
|