// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "cast/common/public/cast_socket.h" #include "cast/common/channel/message_framer.h" #include "cast/common/channel/proto/cast_channel.pb.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace openscreen { namespace cast { using ::cast::channel::CastMessage; namespace { using ::testing::_; using ::testing::Invoke; using ::testing::Return; class CastSocketTest : public ::testing::Test { public: void SetUp() override { message_.set_protocol_version(CastMessage::CASTV2_1_0); message_.set_source_id("source"); message_.set_destination_id("destination"); message_.set_namespace_("namespace"); message_.set_payload_type(CastMessage::STRING); message_.set_payload_utf8("payload"); ErrorOr> serialized_or_error = message_serialization::Serialize(message_); ASSERT_TRUE(serialized_or_error); frame_serial_ = std::move(serialized_or_error.value()); } protected: MockTlsConnection& connection() { return *fake_socket_.connection; } MockCastSocketClient& mock_client() { return fake_socket_.mock_client; } CastSocket& socket() { return fake_socket_.socket; } FakeCastSocket fake_socket_; CastMessage message_; std::vector frame_serial_; }; } // namespace TEST_F(CastSocketTest, SendMessage) { EXPECT_CALL(connection(), Send(_, _)) .WillOnce(Invoke([this](const void* data, size_t len) { EXPECT_EQ( frame_serial_, std::vector(reinterpret_cast(data), reinterpret_cast(data) + len)); return true; })); ASSERT_TRUE(socket().Send(message_).ok()); } TEST_F(CastSocketTest, SendMessageEventuallyBlocks) { EXPECT_CALL(connection(), Send(_, _)) .Times(3) .WillRepeatedly(Invoke([this](const void* data, size_t len) { EXPECT_EQ( frame_serial_, std::vector(reinterpret_cast(data), reinterpret_cast(data) + len)); return true; })) .RetiresOnSaturation(); ASSERT_TRUE(socket().Send(message_).ok()); ASSERT_TRUE(socket().Send(message_).ok()); ASSERT_TRUE(socket().Send(message_).ok()); EXPECT_CALL(connection(), Send(_, _)).WillOnce(Return(false)); ASSERT_EQ(socket().Send(message_).code(), Error::Code::kAgain); } TEST_F(CastSocketTest, ReadCompleteMessage) { const uint8_t* data = frame_serial_.data(); EXPECT_CALL(mock_client(), OnMessage(_, _)) .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); })); connection().OnRead(std::vector(data, data + frame_serial_.size())); } TEST_F(CastSocketTest, ReadChunkedMessage) { const uint8_t* data = frame_serial_.data(); EXPECT_CALL(mock_client(), OnMessage(_, _)).Times(0); connection().OnRead(std::vector(data, data + 10)); EXPECT_CALL(mock_client(), OnMessage(_, _)) .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); })); connection().OnRead( std::vector(data + 10, data + frame_serial_.size())); std::vector double_message; double_message.insert(double_message.end(), frame_serial_.begin(), frame_serial_.end()); double_message.insert(double_message.end(), frame_serial_.begin(), frame_serial_.end()); data = double_message.data(); EXPECT_CALL(mock_client(), OnMessage(_, _)) .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); })); connection().OnRead( std::vector(data, data + frame_serial_.size() + 10)); EXPECT_CALL(mock_client(), OnMessage(_, _)) .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); })); connection().OnRead(std::vector(data + frame_serial_.size() + 10, data + double_message.size())); } TEST_F(CastSocketTest, ReadMultipleMessagesPerBlock) { CastMessage message2; std::vector frame_serial2; message2.set_protocol_version(CastMessage::CASTV2_1_0); message2.set_source_id("alt-source"); message2.set_destination_id("alt-destination"); message2.set_namespace_("alt-namespace"); message2.set_payload_type(CastMessage::STRING); message2.set_payload_utf8("alternate payload"); ErrorOr> serialized_or_error = message_serialization::Serialize(message2); ASSERT_TRUE(serialized_or_error); frame_serial2 = std::move(serialized_or_error.value()); std::vector send_data; send_data.reserve(frame_serial_.size() + frame_serial2.size()); send_data.insert(send_data.end(), frame_serial_.begin(), frame_serial_.end()); send_data.insert(send_data.end(), frame_serial2.begin(), frame_serial2.end()); EXPECT_CALL(mock_client(), OnMessage(_, _)) .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) { EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString()); })) .WillOnce([message2](CastSocket* socket, CastMessage message) { EXPECT_EQ(message2.SerializeAsString(), message.SerializeAsString()); }); connection().OnRead(std::move(send_data)); } TEST_F(CastSocketTest, SanitizedAddress) { std::array result1 = socket().GetSanitizedIpAddress(); EXPECT_EQ(result1[0], 1u); EXPECT_EQ(result1[1], 9u); FakeCastSocket v6_socket(IPEndpoint{{1, 2, 3, 4}, 1025}, IPEndpoint{{0x1819, 0x1a1b, 0x1c1d, 0x1e1f, 0x207b, 0x7c7d, 0x7e7f, 0x8081}, 4321}); std::array result2 = v6_socket.socket.GetSanitizedIpAddress(); EXPECT_EQ(result2[0], 128); EXPECT_EQ(result2[1], 129); } } // namespace cast } // namespace openscreen