// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "discovery/mdns/mdns_responder.h" #include #include "discovery/common/config.h" #include "discovery/mdns/mdns_probe_manager.h" #include "discovery/mdns/mdns_random.h" #include "discovery/mdns/mdns_receiver.h" #include "discovery/mdns/mdns_records.h" #include "discovery/mdns/mdns_sender.h" #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" #include "platform/test/fake_udp_socket.h" namespace openscreen { namespace discovery { namespace { constexpr Clock::duration kMaximumSharedRecordResponseDelayMs(120 * 1000); bool ContainsRecordType(const std::vector& records, DnsType type) { return std::find_if(records.begin(), records.end(), [type](const MdnsRecord& record) { return record.dns_type() == type; }) != records.end(); } void CheckSingleNsecRecordType(const MdnsMessage& message, DnsType type) { ASSERT_EQ(message.answers().size(), size_t{1}); const MdnsRecord record = message.answers()[0]; ASSERT_EQ(record.dns_type(), DnsType::kNSEC); const NsecRecordRdata& rdata = absl::get(record.rdata()); ASSERT_EQ(rdata.types().size(), size_t{1}); EXPECT_EQ(rdata.types()[0], type); } void CheckPtrDomain(const MdnsRecord& record, const DomainName& domain) { ASSERT_EQ(record.dns_type(), DnsType::kPTR); const PtrRecordRdata& rdata = absl::get(record.rdata()); EXPECT_EQ(rdata.ptr_domain(), domain); } void ExpectContainsNsecRecordType(const std::vector& records, DnsType type) { auto it = std::find_if( records.begin(), records.end(), [type](const MdnsRecord& record) { if (record.dns_type() != DnsType::kNSEC) { return false; } const NsecRecordRdata& rdata = absl::get(record.rdata()); return rdata.types().size() == 1 && rdata.types()[0] == type; }); EXPECT_TRUE(it != records.end()); } } // namespace using testing::_; using testing::Args; using testing::Invoke; using testing::Return; using testing::StrictMock; class MockRecordHandler : public MdnsResponder::RecordHandler { public: void AddRecord(MdnsRecord record) { records_.push_back(record); } MOCK_METHOD3(HasRecords, bool(const DomainName&, DnsType, DnsClass)); std::vector GetRecords(const DomainName& name, DnsType type, DnsClass clazz) override { std::vector records; for (const auto& record : records_) { if (type == DnsType::kANY || record.dns_type() == type) { records.push_back(record); } } return records; } std::vector GetPtrRecords(DnsClass clazz) override { std::vector records; for (const auto& record : records_) { if (record.dns_type() == DnsType::kPTR) { records.push_back(record); } } return records; } private: std::vector records_; }; class MockMdnsSender : public MdnsSender { public: explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {} MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message)); MOCK_METHOD2(SendMessage, Error(const MdnsMessage& message, const IPEndpoint& endpoint)); }; class MockProbeManager : public MdnsProbeManager { public: MOCK_CONST_METHOD1(IsDomainClaimed, bool(const DomainName&)); MOCK_METHOD2(RespondToProbeQuery, void(const MdnsMessage&, const IPEndpoint&)); }; class MdnsResponderTest : public testing::Test { public: MdnsResponderTest() : clock_(Clock::now()), task_runner_(&clock_), socket_(&task_runner_), sender_(&socket_), receiver_(config_), responder_(&record_handler_, &probe_manager_, &sender_, &receiver_, &task_runner_, FakeClock::now, &random_, config_) {} protected: MdnsRecord GetFakePtrRecord(const DomainName& target) { DomainName name(++target.labels().begin(), target.labels().end()); PtrRecordRdata rdata(target); return MdnsRecord(std::move(name), DnsType::kPTR, DnsClass::kIN, RecordType::kUnique, std::chrono::seconds(0), rdata); } MdnsRecord GetFakeSrvRecord(const DomainName& name) { SrvRecordRdata rdata(0, 0, 80, name); return MdnsRecord(name, DnsType::kSRV, DnsClass::kIN, RecordType::kUnique, std::chrono::seconds(0), rdata); } MdnsRecord GetFakeTxtRecord(const DomainName& name) { TxtRecordRdata rdata; return MdnsRecord(name, DnsType::kTXT, DnsClass::kIN, RecordType::kUnique, std::chrono::seconds(0), rdata); } MdnsRecord GetFakeARecord(const DomainName& name) { ARecordRdata rdata(IPAddress(192, 168, 0, 0)); return MdnsRecord(name, DnsType::kA, DnsClass::kIN, RecordType::kUnique, std::chrono::seconds(0), rdata); } MdnsRecord GetFakeAAAARecord(const DomainName& name) { AAAARecordRdata rdata(IPAddress(1, 2, 3, 4, 5, 6, 7, 8)); return MdnsRecord(name, DnsType::kAAAA, DnsClass::kIN, RecordType::kUnique, std::chrono::seconds(0), rdata); } void OnMessageReceived(const MdnsMessage& message, const IPEndpoint& src) { responder_.OnMessageReceived(message, src); } void QueryForRecordTypeWhenNonePresent(DnsType type) { MdnsQuestion question(domain_, type, DnsClass::kANY, ResponseType::kMulticast); MdnsMessage message(0, MessageType::Query); message.AddQuestion(question); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([type](const MdnsMessage& msg) -> Error { CheckSingleNsecRecordType(msg, type); return Error::None(); }); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); OnMessageReceived(message, endpoint_); } MdnsMessage CreateMulticastMdnsQuery(DnsType type) { MdnsQuestion question(domain_, type, DnsClass::kANY, ResponseType::kMulticast); MdnsMessage message(0, MessageType::Query); message.AddQuestion(std::move(question)); return message; } MdnsMessage CreateTypeEnumerationQuery() { MdnsQuestion question(type_enumeration_domain_, DnsType::kPTR, DnsClass::kANY, ResponseType::kMulticast); MdnsMessage message(0, MessageType::Query); message.AddQuestion(std::move(question)); return message; } const Config config_; FakeClock clock_; FakeTaskRunner task_runner_; FakeUdpSocket socket_; StrictMock sender_; StrictMock record_handler_; StrictMock probe_manager_; MdnsReceiver receiver_; MdnsRandom random_; MdnsResponder responder_; DomainName domain_{"instance", "_googlecast", "_tcp", "local"}; DomainName type_enumeration_domain_{"_services", "_dns-sd", "_udp", "local"}; IPEndpoint endpoint_{IPAddress(192, 168, 0, 0), 80}; }; // Validate that when records may be sent from multiple receivers, the broadcast // is delayed and it is not delayed otherwise. TEST_F(MdnsResponderTest, OwnedRecordsSentImmediately) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)).Times(1); OnMessageReceived(message, endpoint_); testing::Mock::VerifyAndClearExpectations(&sender_); testing::Mock::VerifyAndClearExpectations(&record_handler_); testing::Mock::VerifyAndClearExpectations(&probe_manager_); EXPECT_CALL(sender_, SendMulticast(_)).Times(0); clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); } TEST_F(MdnsResponderTest, NonOwnedRecordsDelayed) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)).Times(0); OnMessageReceived(message, endpoint_); testing::Mock::VerifyAndClearExpectations(&sender_); testing::Mock::VerifyAndClearExpectations(&record_handler_); testing::Mock::VerifyAndClearExpectations(&probe_manager_); EXPECT_CALL(sender_, SendMulticast(_)).Times(1); clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); } TEST_F(MdnsResponderTest, MultipleQuestionsProcessed) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); MdnsQuestion question2(domain_, DnsType::kANY, DnsClass::kANY, ResponseType::kMulticast); message.AddQuestion(std::move(question2)); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)) .WillOnce(Return(true)) .WillOnce(Return(false)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)).Times(1); OnMessageReceived(message, endpoint_); testing::Mock::VerifyAndClearExpectations(&sender_); testing::Mock::VerifyAndClearExpectations(&record_handler_); testing::Mock::VerifyAndClearExpectations(&probe_manager_); EXPECT_CALL(sender_, SendMulticast(_)).Times(1); clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); } // Validate that the correct messaging scheme (unicast vs multicast) is used. TEST_F(MdnsResponderTest, UnicastMessageSentOverUnicast) { MdnsQuestion question(domain_, DnsType::kANY, DnsClass::kANY, ResponseType::kUnicast); MdnsMessage message(0, MessageType::Query); message.AddQuestion(question); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMessage(_, endpoint_)).Times(1); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, MulticastMessageSentOverMulticast) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)).Times(1); OnMessageReceived(message, endpoint_); } // Validate that records are added as expected based on the query type, and that // additional records are populated as specified in RFC 6762 and 6763. TEST_F(MdnsResponderTest, AnyQueryResultsAllApplied) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(GetFakeAAAARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{4}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kTXT)); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA)); EXPECT_FALSE(ContainsRecordType(message.answers(), DnsType::kPTR)); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, PtrQueryResultsApplied) { DomainName ptr_domain{"_googlecast", "_tcp", "local"}; MdnsQuestion question(ptr_domain, DnsType::kPTR, DnsClass::kANY, ResponseType::kMulticast); MdnsMessage message(0, MessageType::Query); message.AddQuestion(question); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(GetFakeAAAARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{4}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); const auto& records = message.additional_records(); EXPECT_EQ(records.size(), size_t{4}); EXPECT_TRUE(ContainsRecordType(records, DnsType::kSRV)); EXPECT_TRUE(ContainsRecordType(records, DnsType::kTXT)); EXPECT_TRUE(ContainsRecordType(records, DnsType::kA)); EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, SrvQueryResultsApplied) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(GetFakeAAAARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{2}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); const auto& records = message.additional_records(); EXPECT_EQ(records.size(), size_t{2}); EXPECT_TRUE(ContainsRecordType(records, DnsType::kA)); EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, AQueryResultsApplied) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(GetFakeAAAARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{1}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); const auto& records = message.additional_records(); EXPECT_EQ(records.size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(records, DnsType::kAAAA)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kA)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, AAAAQueryResultsApplied) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(GetFakeAAAARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{1}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA)); const auto& records = message.additional_records(); EXPECT_EQ(records.size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(records, DnsType::kA)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kAAAA)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kSRV)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kTXT)); EXPECT_FALSE(ContainsRecordType(records, DnsType::kPTR)); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, MessageOnlySentIfAnswerNotKnown) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA); MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); message.AddAnswer(aaaa_record); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(aaaa_record); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnown) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); message.AddAnswer(aaaa_record); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(aaaa_record); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePackets) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); message.set_truncated(); MdnsMessage message2(1, MessageType::Query); MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); message2.AddAnswer(aaaa_record); OnMessageReceived(message, endpoint_); OnMessageReceived(message2, endpoint_); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(aaaa_record); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); return Error::None(); }); clock_.Advance(std::chrono::seconds(1)); } TEST_F(MdnsResponderTest, RecordOnlySentIfNotKnownMultiplePacketsOutOfOrder) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); message.set_truncated(); MdnsMessage message2(2, MessageType::Query); MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); message2.AddAnswer(aaaa_record); message2.set_truncated(); MdnsMessage message3(3, MessageType::Query); MdnsRecord a_record = GetFakeARecord(domain_); message3.AddAnswer(a_record); OnMessageReceived(message2, endpoint_); OnMessageReceived(message3, endpoint_); OnMessageReceived(message, endpoint_); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(a_record); record_handler_.AddRecord(aaaa_record); record_handler_.AddRecord(aaaa_record); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); return Error::None(); }); clock_.Advance(std::chrono::seconds(1)); } TEST_F(MdnsResponderTest, RecordSentForMultiPacketsSuppressionIfMoreNotFound) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kANY); MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); message.AddAnswer(aaaa_record); message.set_truncated(); OnMessageReceived(message, endpoint_); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakeARecord(domain_)); record_handler_.AddRecord(aaaa_record); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.additional_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); return Error::None(); }); clock_.Advance(std::chrono::seconds(1)); } TEST_F(MdnsResponderTest, RecordNotSentForMultiPacketsSuppressionIfNoQuery) { MdnsMessage message(1, MessageType::Query); MdnsRecord aaaa_record = GetFakeAAAARecord(domain_); message.AddAnswer(aaaa_record); OnMessageReceived(message, endpoint_); clock_.Advance(std::chrono::seconds(1)); } // Validate NSEC records are used correctly. TEST_F(MdnsResponderTest, QueryForRecordTypesWhenNonePresent) { QueryForRecordTypeWhenNonePresent(DnsType::kANY); QueryForRecordTypeWhenNonePresent(DnsType::kSRV); QueryForRecordTypeWhenNonePresent(DnsType::kTXT); QueryForRecordTypeWhenNonePresent(DnsType::kA); QueryForRecordTypeWhenNonePresent(DnsType::kAAAA); } TEST_F(MdnsResponderTest, AAAAQueryGiveANsec) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kAAAA); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeAAAARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kAAAA)); EXPECT_EQ(message.additional_records().size(), size_t{1}); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, AQueryGiveAAAANsec) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kA); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kA)); EXPECT_EQ(message.additional_records().size(), size_t{1}); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kAAAA); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsecForNoAOrAAAA) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); EXPECT_EQ(message.additional_records().size(), size_t{2}); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kAAAA); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, SrvQueryGiveCorrectNsec) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kSRV); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kSRV)); EXPECT_EQ(message.additional_records().size(), size_t{2}); EXPECT_TRUE( ContainsRecordType(message.additional_records(), DnsType::kA)); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kAAAA); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForNoPtrOrSrv) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); EXPECT_EQ(message.additional_records().size(), size_t{2}); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kTXT); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kSRV); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlyPtr) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); EXPECT_EQ(message.additional_records().size(), size_t{2}); EXPECT_TRUE( ContainsRecordType(message.additional_records(), DnsType::kTXT)); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kSRV); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, PtrQueryGiveCorrectNsecForOnlySrv) { MdnsMessage message = CreateMulticastMdnsQuery(DnsType::kPTR); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(true)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); record_handler_.AddRecord(GetFakePtrRecord(domain_)); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); EXPECT_EQ(message.additional_records().size(), size_t{4}); EXPECT_TRUE( ContainsRecordType(message.additional_records(), DnsType::kSRV)); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kTXT); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kA); ExpectContainsNsecRecordType(message.additional_records(), DnsType::kAAAA); return Error::None(); }); OnMessageReceived(message, endpoint_); } TEST_F(MdnsResponderTest, EnumerateAllQuery) { MdnsMessage message = CreateTypeEnumerationQuery(); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); const auto ptr = GetFakePtrRecord(domain_); record_handler_.AddRecord(ptr); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); OnMessageReceived(message, endpoint_); EXPECT_CALL(sender_, SendMulticast(_)) .WillOnce([this, &ptr](const MdnsMessage& message) -> Error { EXPECT_EQ(message.questions().size(), size_t{0}); EXPECT_EQ(message.authority_records().size(), size_t{0}); EXPECT_EQ(message.answers().size(), size_t{1}); EXPECT_TRUE(ContainsRecordType(message.answers(), DnsType::kPTR)); EXPECT_EQ(message.answers()[0].name(), type_enumeration_domain_); CheckPtrDomain(message.answers()[0], ptr.name()); return Error::None(); }); clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); } TEST_F(MdnsResponderTest, EnumerateAllQueryNoResults) { MdnsMessage message = CreateTypeEnumerationQuery(); EXPECT_CALL(probe_manager_, IsDomainClaimed(_)).WillOnce(Return(false)); EXPECT_CALL(record_handler_, HasRecords(_, _, _)) .WillRepeatedly(Return(true)); const auto ptr = GetFakePtrRecord(domain_); record_handler_.AddRecord(GetFakeSrvRecord(domain_)); record_handler_.AddRecord(GetFakeTxtRecord(domain_)); record_handler_.AddRecord(GetFakeARecord(domain_)); OnMessageReceived(message, endpoint_); clock_.Advance(Clock::duration(kMaximumSharedRecordResponseDelayMs)); } } // namespace discovery } // namespace openscreen