You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
359 lines
13 KiB
359 lines
13 KiB
// Copyright 2020 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "discovery/mdns/mdns_probe_manager.h"

#include <algorithm>
#include <chrono>
#include <memory>
#include <utility>

#include "discovery/common/config.h"
#include "discovery/mdns/mdns_probe.h"
#include "discovery/mdns/mdns_querier.h"
#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_receiver.h"
#include "discovery/mdns/mdns_sender.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "platform/test/fake_clock.h"
#include "platform/test/fake_task_runner.h"
#include "platform/test/fake_udp_socket.h"

using testing::_;
using testing::Invoke;
using testing::Return;
using testing::StrictMock;
|
|
|
|
namespace openscreen {
|
|
namespace discovery {
|
|
|
|
class MockDomainConfirmedProvider : public MdnsDomainConfirmedProvider {
|
|
public:
|
|
MOCK_METHOD2(OnDomainFound, void(const DomainName&, const DomainName&));
|
|
};
|
|
|
|
class MockMdnsSender : public MdnsSender {
|
|
public:
|
|
explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
|
|
|
|
MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
|
|
MOCK_METHOD2(SendMessage,
|
|
Error(const MdnsMessage& message, const IPEndpoint& endpoint));
|
|
};
|
|
|
|
class MockMdnsProbe : public MdnsProbe {
|
|
public:
|
|
MockMdnsProbe(DomainName target_name, IPAddress address)
|
|
: MdnsProbe(std::move(target_name), std::move(address)) {}
|
|
|
|
MOCK_METHOD1(Postpone, void(std::chrono::seconds));
|
|
MOCK_METHOD1(OnMessageReceived, void(const MdnsMessage&));
|
|
};
|
|
|
|
class TestMdnsProbeManager : public MdnsProbeManagerImpl {
|
|
public:
|
|
using MdnsProbeManagerImpl::MdnsProbeManagerImpl;
|
|
|
|
using MdnsProbeManagerImpl::OnProbeFailure;
|
|
using MdnsProbeManagerImpl::OnProbeSuccess;
|
|
|
|
std::unique_ptr<MdnsProbe> CreateProbe(DomainName name,
|
|
IPAddress address) override {
|
|
return std::make_unique<StrictMock<MockMdnsProbe>>(std::move(name),
|
|
std::move(address));
|
|
}
|
|
|
|
StrictMock<MockMdnsProbe>* GetOngoingMockProbeByTarget(
|
|
const DomainName& target) {
|
|
const auto it =
|
|
std::find_if(ongoing_probes_.begin(), ongoing_probes_.end(),
|
|
[&target](const OngoingProbe& ongoing) {
|
|
return ongoing.probe->target_name() == target;
|
|
});
|
|
if (it != ongoing_probes_.end()) {
|
|
return static_cast<StrictMock<MockMdnsProbe>*>(it->probe.get());
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
StrictMock<MockMdnsProbe>* GetCompletedMockProbe(const DomainName& target) {
|
|
const auto it = FindCompletedProbe(target);
|
|
if (it != completed_probes_.end()) {
|
|
return static_cast<StrictMock<MockMdnsProbe>*>(it->get());
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
bool HasOngoingProbe(const DomainName& target) {
|
|
return GetOngoingMockProbeByTarget(target) != nullptr;
|
|
}
|
|
|
|
bool HasCompletedProbe(const DomainName& target) {
|
|
return GetCompletedMockProbe(target) != nullptr;
|
|
}
|
|
|
|
size_t GetOngoingProbeCount() { return ongoing_probes_.size(); }
|
|
|
|
size_t GetCompletedProbeCount() { return completed_probes_.size(); }
|
|
};
|
|
|
|
class MdnsProbeManagerTests : public testing::Test {
|
|
public:
|
|
MdnsProbeManagerTests()
|
|
: clock_(Clock::now()),
|
|
task_runner_(&clock_),
|
|
socket_(&task_runner_),
|
|
sender_(&socket_),
|
|
receiver_(config_),
|
|
manager_(&sender_,
|
|
&receiver_,
|
|
&random_,
|
|
&task_runner_,
|
|
FakeClock::now) {
|
|
ExpectProbeStopped(name_);
|
|
ExpectProbeStopped(name2_);
|
|
ExpectProbeStopped(name_retry_);
|
|
}
|
|
|
|
protected:
|
|
MdnsMessage CreateProbeQueryMessage(DomainName domain,
|
|
const IPAddress& address) {
|
|
MdnsMessage message(CreateMessageId(), MessageType::Query);
|
|
MdnsQuestion question(domain, DnsType::kANY, DnsClass::kANY,
|
|
ResponseType::kUnicast);
|
|
MdnsRecord record = CreateAddressRecord(std::move(domain), address);
|
|
message.AddQuestion(std::move(question));
|
|
message.AddAuthorityRecord(std::move(record));
|
|
return message;
|
|
}
|
|
|
|
void ExpectProbeStopped(const DomainName& name) {
|
|
EXPECT_FALSE(manager_.HasOngoingProbe(name));
|
|
EXPECT_FALSE(manager_.HasCompletedProbe(name));
|
|
EXPECT_FALSE(manager_.IsDomainClaimed(name));
|
|
}
|
|
|
|
StrictMock<MockMdnsProbe>* ExpectProbeOngoing(const DomainName& name) {
|
|
// Get around limitations of using an assert in a function with a return
|
|
// value.
|
|
auto validate = [this, &name]() {
|
|
ASSERT_TRUE(manager_.HasOngoingProbe(name));
|
|
EXPECT_FALSE(manager_.HasCompletedProbe(name));
|
|
EXPECT_FALSE(manager_.IsDomainClaimed(name));
|
|
};
|
|
validate();
|
|
|
|
return manager_.GetOngoingMockProbeByTarget(name);
|
|
}
|
|
|
|
StrictMock<MockMdnsProbe>* ExpectProbeCompleted(const DomainName& name) {
|
|
// Get around limitations of using an assert in a function with a return
|
|
// value.
|
|
auto validate = [this, &name]() {
|
|
EXPECT_FALSE(manager_.HasOngoingProbe(name));
|
|
ASSERT_TRUE(manager_.HasCompletedProbe(name));
|
|
EXPECT_TRUE(manager_.IsDomainClaimed(name));
|
|
};
|
|
validate();
|
|
|
|
return manager_.GetCompletedMockProbe(name);
|
|
}
|
|
|
|
StrictMock<MockMdnsProbe>* SetUpCompletedProbe(const DomainName& name,
|
|
const IPAddress& address) {
|
|
EXPECT_TRUE(manager_.StartProbe(&callback_, name, address).ok());
|
|
EXPECT_CALL(callback_, OnDomainFound(name, name));
|
|
StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name);
|
|
manager_.OnProbeSuccess(ongoing_probe);
|
|
ExpectProbeCompleted(name);
|
|
testing::Mock::VerifyAndClearExpectations(ongoing_probe);
|
|
|
|
return ongoing_probe;
|
|
}
|
|
|
|
Config config_;
|
|
FakeClock clock_;
|
|
FakeTaskRunner task_runner_;
|
|
FakeUdpSocket socket_;
|
|
StrictMock<MockMdnsSender> sender_;
|
|
MdnsReceiver receiver_;
|
|
MdnsRandom random_;
|
|
StrictMock<TestMdnsProbeManager> manager_;
|
|
MockDomainConfirmedProvider callback_;
|
|
|
|
const DomainName name_{"test", "_googlecast", "_tcp", "local"};
|
|
const DomainName name_retry_{"test1", "_googlecast", "_tcp", "local"};
|
|
const DomainName name2_{"test2", "_googlecast", "_tcp", "local"};
|
|
|
|
// When used to create address records A, B, C, A > B because comparison of
|
|
// the rdata in each results in the comparison of endpoints, for which
|
|
// address_b_ < address_a_. A < C because A is DnsType kA with value 1 and
|
|
// C is DnsType kAAAA with value 28.
|
|
const IPAddress address_a_{192, 168, 0, 0};
|
|
const IPAddress address_b_{190, 160, 0, 0};
|
|
const IPAddress address_c_{0x0102, 0x0304, 0x0506, 0x0708,
|
|
0x090a, 0x0b0c, 0x0d0e, 0x0f10};
|
|
const IPEndpoint endpoint_{{192, 168, 0, 0}, 80};
|
|
};
|
|
|
|
// StartProbe succeeds only for a domain that has neither an ongoing nor a
// completed probe.
TEST_F(MdnsProbeManagerTests, StartProbeBeginsProbeWhenNoneExistsOnly) {
  // The first request starts a probe; an unrelated domain stays unclaimed.
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  ExpectProbeOngoing(name_);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));

  // A second request while the probe is running is rejected.
  EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok());

  // Completing the probe reports the claimed name through |callback_|.
  EXPECT_CALL(callback_, OnDomainFound(name_, name_));
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  manager_.OnProbeSuccess(ongoing_probe);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
  // The expectation above is on |callback_| (not on the probe), so verify that
  // mock. Afterwards, forbid any further callbacks: |callback_| is not a
  // StrictMock, so an extra call would otherwise only log a warning.
  testing::Mock::VerifyAndClearExpectations(&callback_);
  EXPECT_CALL(callback_, OnDomainFound(_, _)).Times(0);

  // Once the probe has completed, a further request is still rejected.
  EXPECT_FALSE(manager_.StartProbe(&callback_, name_, address_a_).ok());

  StrictMock<MockMdnsProbe>* completed_probe = ExpectProbeCompleted(name_);
  EXPECT_EQ(ongoing_probe, completed_probe);
  EXPECT_FALSE(manager_.IsDomainClaimed(name2_));
}
|
|
|
|
// StopProbe succeeds only while a probe is running. It fails when nothing was
// started and after the probe has completed.
TEST_F(MdnsProbeManagerTests, StopProbeChangesOngoingProbesOnly) {
  // Nothing has been started yet, so there is nothing to stop.
  const bool stopped_before_start = manager_.StopProbe(name_).ok();
  EXPECT_FALSE(stopped_before_start);
  ExpectProbeStopped(name_);

  // A running probe can be stopped, which returns the domain to "stopped".
  const bool started = manager_.StartProbe(&callback_, name_, address_a_).ok();
  EXPECT_TRUE(started);
  ExpectProbeOngoing(name_);

  const bool stopped_while_running = manager_.StopProbe(name_).ok();
  EXPECT_TRUE(stopped_while_running);
  ExpectProbeStopped(name_);

  // A completed probe cannot be stopped and stays completed.
  SetUpCompletedProbe(name_, address_a_);

  const bool stopped_after_completion = manager_.StopProbe(name_).ok();
  EXPECT_FALSE(stopped_after_completion);
  ExpectProbeCompleted(name_);
}
|
|
|
|
// A probe query for a domain this manager has never probed for gets no reply.
// |sender_| is a StrictMock, so any send attempt fails the test.
TEST_F(MdnsProbeManagerTests, RespondToProbeQuerySendsNothingOnUnownedDomain) {
  manager_.RespondToProbeQuery(CreateProbeQueryMessage(name_, address_c_),
                               endpoint_);
}
|
|
|
|
// Once a probe has completed, a probe query for that domain is answered with a
// message sent to the querier's endpoint. The message carries one A record for
// the claimed name.
TEST_F(MdnsProbeManagerTests, RespondToProbeQueryWorksForCompletedProbes) {
  SetUpCompletedProbe(name_, address_a_);

  const MdnsMessage query = CreateProbeQueryMessage(name_, address_c_);
  EXPECT_CALL(sender_, SendMessage(_, endpoint_))
      .WillOnce([this](const MdnsMessage& message,
                       const IPEndpoint& /* endpoint */) -> Error {
        const auto& answers = message.answers();
        EXPECT_EQ(answers.size(), size_t{1});
        // EXPECT_* does not stop execution (and ASSERT_* cannot be used in a
        // lambda that returns a value), so guard the element access. Without
        // the guard, an empty answer list would be read out of bounds.
        if (!answers.empty()) {
          EXPECT_EQ(answers[0].dns_type(), DnsType::kA);
          EXPECT_EQ(answers[0].name(), this->name_);
        }
        return Error::None();
      });
  manager_.RespondToProbeQuery(query, endpoint_);
}
|
|
|
|
// While a probe is running, a conflicting single-record probe query is settled
// by comparing its record with the probe's own record. With records built from
// address_b_ < address_a_ < address_c_ and a probe holding A:
//   query A -> identical, ignored
//   query B -> lower, ignored
//   query C -> higher, the running probe is postponed.
TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForSingleRecordQueries) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  // ExpectProbeOngoing returns nullptr when no probe is running. Stop here
  // rather than dereference it in EXPECT_CALL below.
  ASSERT_NE(ongoing_probe, nullptr);

  // A query that matches the running probe is ignored.
  MdnsMessage query = CreateProbeQueryMessage(name_, address_a_);
  manager_.RespondToProbeQuery(query, endpoint_);

  // A query whose record is lower than the running probe's is ignored.
  query = CreateProbeQueryMessage(name_, address_b_);
  manager_.RespondToProbeQuery(query, endpoint_);

  // A query whose record is higher than the running probe's postpones the
  // running probe.
  query = CreateProbeQueryMessage(name_, address_c_);
  EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(query, endpoint_);
}
|
|
|
|
// While a probe is running, a conflicting multi-record probe query is settled
// by sorting the query's records and comparing them with the probe's records
// in order.
TEST_F(MdnsProbeManagerTests, TiebreakProbeQueryWorksForMultiRecordQueries) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  // ExpectProbeOngoing returns nullptr when no probe is running. Stop here
  // rather than dereference it in the EXPECT_CALLs below.
  ASSERT_NE(ongoing_probe, nullptr);

  // Records A, B and C are built from |address_a_|, |address_b_| and
  // |address_c_| respectively, and order as B < A < C. The running probe
  // holds only A.
  //
  // If the query contains a record lower than A (here B), its lowest record
  // loses against A whatever order the records were added in, so the query is
  // ignored.
  MdnsMessage query = CreateProbeQueryMessage(name_, address_b_);
  query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
  manager_.RespondToProbeQuery(query, endpoint_);

  query = CreateProbeQueryMessage(name_, address_c_);
  query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_));
  manager_.RespondToProbeQuery(query, endpoint_);

  query = CreateProbeQueryMessage(name_, address_a_);
  query.AddAuthorityRecord(CreateAddressRecord(name_, address_b_));
  query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
  manager_.RespondToProbeQuery(query, endpoint_);

  // If the query's lowest record equals the probe's record (A) and the query
  // has further records, the query wins and the running probe is postponed.
  // Again, this does not depend on the order the records were added in.
  query = CreateProbeQueryMessage(name_, address_a_);
  query.AddAuthorityRecord(CreateAddressRecord(name_, address_c_));
  EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(query, endpoint_);
  testing::Mock::VerifyAndClearExpectations(ongoing_probe);

  query = CreateProbeQueryMessage(name_, address_c_);
  query.AddAuthorityRecord(CreateAddressRecord(name_, address_a_));
  EXPECT_CALL(*ongoing_probe, Postpone(_)).Times(1);
  manager_.RespondToProbeQuery(query, endpoint_);
}
|
|
|
|
// Reporting success for a probe that has already been stopped is ignored. No
// callback is made and the domain stays unclaimed.
TEST_F(MdnsProbeManagerTests, ProbeSuccessAfterProbeRemovalNoOp) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  EXPECT_TRUE(manager_.StopProbe(name_).ok());

  // |callback_| is not a StrictMock, so an unexpected call would only log a
  // warning. Make the "no callback" requirement an actual failure.
  EXPECT_CALL(callback_, OnDomainFound(_, _)).Times(0);
  manager_.OnProbeSuccess(ongoing_probe);

  // Nothing was completed or restarted by the late success report.
  ExpectProbeStopped(name_);
  EXPECT_EQ(manager_.GetOngoingProbeCount(), size_t{0});
  EXPECT_EQ(manager_.GetCompletedProbeCount(), size_t{0});
}
|
|
|
|
// Reporting failure for a probe that has already been stopped is ignored. No
// callback is made and no retry probe is started.
TEST_F(MdnsProbeManagerTests, ProbeFailureAfterProbeRemovalNoOp) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  EXPECT_TRUE(manager_.StopProbe(name_).ok());

  // |callback_| is not a StrictMock, so an unexpected call would only log a
  // warning. Make the "no callback" requirement an actual failure.
  EXPECT_CALL(callback_, OnDomainFound(_, _)).Times(0);
  manager_.OnProbeFailure(ongoing_probe);

  // A real failure would start a probe for the fallback name |name_retry_|.
  // A late failure for a removed probe must not.
  ExpectProbeStopped(name_);
  ExpectProbeStopped(name_retry_);
  EXPECT_EQ(manager_.GetOngoingProbeCount(), size_t{0});
  EXPECT_EQ(manager_.GetCompletedProbeCount(), size_t{0});
}
|
|
|
|
// When a probe fails and the fallback name it would retry with is already
// claimed, the manager reports the fallback name at once instead of probing
// for it again.
TEST_F(MdnsProbeManagerTests, ProbeFailureCallsCallbackWhenAlreadyClaimed) {
  // Claim |name_retry_| up front. It is the name that a failed probe for
  // |name_| falls back to.
  SetUpCompletedProbe(name_retry_, address_a_);

  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* const failing_probe = ExpectProbeOngoing(name_);

  // |name_retry_| has already succeeded, so the retry logic should skip
  // re-probing it and jump straight to success.
  EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
  manager_.OnProbeFailure(failing_probe);

  ExpectProbeStopped(name_);
  ExpectProbeCompleted(name_retry_);
}
|
|
|
|
// When a probe fails and the fallback name is unclaimed, the manager starts a
// new probe for the fallback name. Once that probe succeeds, the manager
// reports the original request as claimed under the fallback name.
TEST_F(MdnsProbeManagerTests, ProbeFailureCreatesNewProbeIfNameUnclaimed) {
  EXPECT_TRUE(manager_.StartProbe(&callback_, name_, address_a_).ok());
  StrictMock<MockMdnsProbe>* ongoing_probe = ExpectProbeOngoing(name_);
  manager_.OnProbeFailure(ongoing_probe);
  ExpectProbeStopped(name_);

  ongoing_probe = ExpectProbeOngoing(name_retry_);
  // ExpectProbeOngoing returns nullptr when no probe is running. Stop here
  // rather than dereference it on the next line.
  ASSERT_NE(ongoing_probe, nullptr);
  EXPECT_EQ(ongoing_probe->target_name(), name_retry_);

  EXPECT_CALL(callback_, OnDomainFound(name_, name_retry_));
  manager_.OnProbeSuccess(ongoing_probe);
  ExpectProbeCompleted(name_retry_);
  ExpectProbeStopped(name_);
}
|
|
|
|
} // namespace discovery
|
|
} // namespace openscreen
|