You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

313 lines
11 KiB

// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/dnssd/impl/publisher_impl.h"

#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "discovery/common/reporting_client.h"
#include "discovery/dnssd/impl/conversion_layer.h"
#include "discovery/dnssd/impl/instance_key.h"
#include "discovery/dnssd/impl/network_interface_config.h"
#include "discovery/mdns/public/mdns_constants.h"
#include "platform/api/task_runner.h"
#include "platform/base/error.h"
#include "util/trace_logging.h"
namespace openscreen {
namespace discovery {
namespace {
// Builds the DnsSdInstanceEndpoint for |instance| using the instance, service
// and domain ids held by |key|. One IPEndpoint is added per address family
// that |network_config| reports an address for (IPv4 first, then IPv6), each
// paired with the instance's port.
DnsSdInstanceEndpoint CreateEndpoint(
    DnsSdInstance instance,
    InstanceKey key,
    const NetworkInterfaceConfig& network_config) {
  const bool has_v4 = network_config.HasAddressV4();
  const bool has_v6 = network_config.HasAddressV6();

  std::vector<IPEndpoint> addresses;
  if (has_v4) {
    addresses.push_back(
        IPEndpoint{network_config.address_v4(), instance.port()});
  }
  if (has_v6) {
    addresses.push_back(
        IPEndpoint{network_config.address_v6(), instance.port()});
  }

  // The ids are taken from |key| rather than from |instance|; UpdateDomain()
  // relies on this to re-key an instance under a different domain name.
  return DnsSdInstanceEndpoint(key.instance_id(), key.service_id(),
                               key.domain_id(), instance.txt(),
                               network_config.network_interface(),
                               std::move(addresses));
}
// Builds an endpoint that carries |instance|'s data but is identified by the
// InstanceKey constructed from the domain |name| instead of by the ids stored
// in |instance| itself.
DnsSdInstanceEndpoint UpdateDomain(
    const DomainName& name,
    DnsSdInstance instance,
    const NetworkInterfaceConfig& network_config) {
  InstanceKey renamed_key(name);
  return CreateEndpoint(std::move(instance), std::move(renamed_key),
                        network_config);
}
// Convenience overload: builds the endpoint for |instance| keyed by the
// instance's own ids.
DnsSdInstanceEndpoint CreateEndpoint(
    DnsSdInstance instance,
    const NetworkInterfaceConfig& network_config) {
  // The key has to be derived in its own statement, before |instance| is
  // moved from. Function arguments are evaluated in an unspecified order, so
  // constructing it inline in the call below could read a moved-from object.
  InstanceKey own_key(instance);
  return CreateEndpoint(std::move(instance), std::move(own_key),
                        network_config);
}
template <typename T>
inline typename std::map<DnsSdInstance, T>::iterator FindKey(
std::map<DnsSdInstance, T>* instances,
const InstanceKey& key) {
return std::find_if(instances->begin(), instances->end(),
[&key](const std::pair<DnsSdInstance, T>& pair) {
return key == InstanceKey(pair.first);
});
}
// Erases every entry of |instances| whose key reports |service_id| as its
// service id and returns how many entries were erased.
template <typename T>
int EraseInstancesWithServiceId(std::map<DnsSdInstance, T>* instances,
                                const std::string& service_id) {
  int erased = 0;
  auto cursor = instances->begin();
  while (cursor != instances->end()) {
    if (cursor->first.service_id() != service_id) {
      ++cursor;
      continue;
    }
    // erase() hands back the iterator following the removed entry, which
    // keeps the traversal valid while the map is being modified.
    cursor = instances->erase(cursor);
    ++erased;
  }
  return erased;
}
} // namespace
// Stores the collaborators this publisher works through. All four pointers
// are kept as-is (nothing is copied) and are expected to be non-null.
PublisherImpl::PublisherImpl(MdnsService* publisher,
                             ReportingClient* reporting_client,
                             TaskRunner* task_runner,
                             const NetworkInterfaceConfig* network_config)
    : mdns_publisher_(publisher),
      reporting_client_(reporting_client),
      task_runner_(task_runner),
      network_config_(network_config) {
  OSP_DCHECK(mdns_publisher_);
  OSP_DCHECK(reporting_client_);
  OSP_DCHECK(task_runner_);
  // |network_config_| is dereferenced unconditionally by Register() and by the
  // update / domain-claimed paths, so a null config is a programming error
  // just like the other collaborators.
  OSP_DCHECK(network_config_);
}
// Defaulted destructor: no explicit teardown is performed here.
PublisherImpl::~PublisherImpl() = default;
// Starts publication of |instance| on behalf of |client|.
//
// - If an identical instance is already published, this is treated as an
//   update and the result of UpdateRegistration() is returned.
// - If an identical instance is already awaiting its probe result,
//   kOperationInProgress is returned.
// - Otherwise the instance is recorded as pending and a probe for its domain
//   name is started; the probe's start result is returned. Publication itself
//   completes later in OnDomainFound().
Error PublisherImpl::Register(const DnsSdInstance& instance, Client* client) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  OSP_DCHECK(client != nullptr);

  if (published_instances_.find(instance) != published_instances_.end()) {
    // Return here instead of falling through: the instance is already
    // published, so its update result must not be dropped and no second probe
    // may be started for it.
    return UpdateRegistration(instance);
  }
  if (pending_instances_.find(instance) != pending_instances_.end()) {
    return Error::Code::kOperationInProgress;
  }

  InstanceKey key(instance);
  const IPAddress& address = network_config_->GetAddress();
  OSP_DCHECK(address);
  const auto inserted = pending_instances_.emplace(
      CreateEndpoint(instance, *network_config_), client);

  OSP_DVLOG << "Registering instance '" << instance.instance_id() << "'";

  Error probe_result =
      mdns_publisher_->StartProbe(this, GetDomainName(key), address);
  if (!probe_result.ok() && inserted.second) {
    // No probe is running for this entry, so OnDomainFound() would never
    // consume it. Drop it again so that a later Register() for the same
    // instance is not rejected with kOperationInProgress.
    pending_instances_.erase(inserted.first);
  }
  return probe_result;
}
// Updates the data of an instance that shares its InstanceKey with
// |instance|. An instance that is still pending publication is replaced in
// place; otherwise the request is forwarded to UpdatePublishedRegistration().
Error PublisherImpl::UpdateRegistration(const DnsSdInstance& instance) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  // Look for the instance among those still waiting on their probe.
  const InstanceKey lookup_key(instance);
  auto pending_it = FindKey(&pending_instances_, lookup_key);

  OSP_DVLOG << "Updating instance '" << instance.instance_id() << "'";

  // Not pending: the only remaining candidate is a published instance.
  if (pending_it == pending_instances_.end()) {
    return UpdatePublishedRegistration(instance);
  }

  // The instance, service and domain ids are unchanged, so the probe already
  // in flight stays valid. Only the stored instance data is swapped out, and
  // the client that registered it is carried over.
  Client* const owner = pending_it->second;
  pending_instances_.erase(pending_it);
  pending_instances_.emplace(CreateEndpoint(instance, *network_config_), owner);
  return Error::None();
}
// Replaces an already published instance with |instance| and brings the mDNS
// records in line with the new data.
//
// Returns kParameterInvalid when no published instance shares |instance|'s
// InstanceKey, or when the update would not change the published endpoint.
// Otherwise the published entry is always replaced, and the return value is
// Error::None() or the last error reported by the mDNS publisher while
// registering, unregistering or updating the individual records.
Error PublisherImpl::UpdatePublishedRegistration(
    const DnsSdInstance& instance) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  auto published_instance_it =
      FindKey(&published_instances_, InstanceKey(instance));

  // Check preconditions called out in header. Specifically, the updated
  // instance must be making changes to an already published instance.
  if (published_instance_it == published_instances_.end()) {
    return Error::Code::kParameterInvalid;
  }

  // Keep the domain name the instance is currently published under, since it
  // may differ from the one derived from |instance|. Deliberately non-const:
  // the endpoint is moved into |published_instances_| at the end, and
  // std::move on a const object would silently copy instead.
  DnsSdInstanceEndpoint updated_endpoint =
      UpdateDomain(GetDomainName(InstanceKey(published_instance_it->second)),
                   instance, *network_config_);
  if (published_instance_it->second == updated_endpoint) {
    return Error::Code::kParameterInvalid;
  }

  // Get all records which have changed. By design, there can only be one
  // record of each DnsType, so use that here to simplify this step. First in
  // each pair is the old record, second is the new record.
  using RecordChange =
      std::pair<absl::optional<MdnsRecord>, absl::optional<MdnsRecord>>;
  std::map<DnsType, RecordChange> changed_records;

  // Non-const so that the individual records can be moved out below.
  std::vector<MdnsRecord> old_records =
      GetDnsRecords(published_instance_it->second);
  std::vector<MdnsRecord> new_records = GetDnsRecords(updated_endpoint);

  // Populate the first part of each pair in |changed_records|.
  for (MdnsRecord& old_record : old_records) {
    // Read the type before the record is moved from.
    const auto key = old_record.dns_type();
    OSP_DCHECK(changed_records.find(key) == changed_records.end());
    changed_records.emplace(
        key, RecordChange(std::move(old_record), absl::nullopt));
  }

  // Populate the second part of each pair in |changed_records|.
  for (MdnsRecord& new_record : new_records) {
    // Read the type before the record is moved from.
    const auto key = new_record.dns_type();
    auto find_it = changed_records.find(key);
    if (find_it == changed_records.end()) {
      changed_records.emplace(
          key, RecordChange(absl::nullopt, std::move(new_record)));
    } else {
      find_it->second.second = std::move(new_record);
    }
  }

  // Apply changes called out in |changed_records|. Every change is attempted
  // even if an earlier one failed; the last failure is what gets returned.
  Error total_result = Error::None();
  for (const auto& change : changed_records) {
    OSP_DCHECK(change.second.first != absl::nullopt ||
               change.second.second != absl::nullopt);
    if (change.second.first == absl::nullopt) {
      // Only a new record exists for this type: publish it.
      TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
      auto error =
          mdns_publisher_->RegisterRecord(change.second.second.value());
      TRACE_SET_RESULT(error);
      if (!error.ok()) {
        total_result = error;
      }
    } else if (change.second.second == absl::nullopt) {
      // Only an old record exists for this type: withdraw it.
      TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
      auto error =
          mdns_publisher_->UnregisterRecord(change.second.first.value());
      TRACE_SET_RESULT(error);
      if (!error.ok()) {
        total_result = error;
      }
    } else if (change.second.first.value() != change.second.second.value()) {
      // Both exist and differ: update in place. Identical records are skipped.
      TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UpdateRegisteredRecord");
      auto error = mdns_publisher_->UpdateRegisteredRecord(
          change.second.first.value(), change.second.second.value());
      TRACE_SET_RESULT(error);
      if (!error.ok()) {
        total_result = error;
      }
    }
  }

  // Replace the old instance with the new one.
  published_instances_.erase(published_instance_it);
  published_instances_.emplace(instance, std::move(updated_endpoint));

  return total_result;
}
// Removes every published and pending instance whose service id equals
// |service|. The mDNS records of each published instance are unregistered
// first.
//
// Returns the number of instances removed (published plus pending), or, if
// any record failed to unregister, the last such error. The instances are
// removed from the local maps in either case.
ErrorOr<int> PublisherImpl::DeregisterAll(const std::string& service) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  OSP_DVLOG << "Deregistering all instances";

  int removed_count = 0;
  Error error = Error::None();
  for (auto it = published_instances_.begin();
       it != published_instances_.end();) {
    if (it->second.service_id() == service) {
      for (const auto& mdns_record : GetDnsRecords(it->second)) {
        TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.UnregisterRecord");
        auto publisher_error = mdns_publisher_->UnregisterRecord(mdns_record);
        // Trace the outcome of this particular call. |error| only holds the
        // most recent failure accumulated across all records.
        TRACE_SET_RESULT(publisher_error);
        if (!publisher_error.ok()) {
          error = publisher_error;
        }
      }
      removed_count++;
      // erase() returns the iterator following the removed entry.
      it = published_instances_.erase(it);
    } else {
      ++it;
    }
  }

  // Pending instances have no records registered here, so they only need to
  // be dropped from the map.
  removed_count += EraseInstancesWithServiceId(&pending_instances_, service);

  if (!error.ok()) {
    return error;
  } else {
    return removed_count;
  }
}
void PublisherImpl::OnDomainFound(const DomainName& requested_name,
const DomainName& confirmed_name) {
TRACE_DEFAULT_SCOPED(TraceCategory::kDiscovery);
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
OSP_DVLOG << "Domain successfully claimed: '" << confirmed_name.ToString()
<< "' based on requested name: '" << requested_name.ToString()
<< "'";
auto it = FindKey(&pending_instances_, InstanceKey(requested_name));
if (it == pending_instances_.end()) {
// This will be hit if the instance was deregister'd before the probe phase
// was completed.
return;
}
DnsSdInstance requested_instance = std::move(it->first);
DnsSdInstanceEndpoint endpoint =
CreateEndpoint(requested_instance, *network_config_);
Client* const client = it->second;
pending_instances_.erase(it);
InstanceKey requested_key(requested_instance);
if (requested_name != confirmed_name) {
OSP_DCHECK(HasValidDnsRecordAddress(confirmed_name));
endpoint =
UpdateDomain(confirmed_name, requested_instance, *network_config_);
}
for (const auto& mdns_record : GetDnsRecords(endpoint)) {
TRACE_SCOPED(TraceCategory::kDiscovery, "mdns.RegisterRecord");
Error result = mdns_publisher_->RegisterRecord(mdns_record);
if (!result.ok()) {
reporting_client_->OnRecoverableError(
Error(Error::Code::kRecordPublicationError, result.ToString()));
}
}
auto pair = published_instances_.emplace(std::move(requested_instance),
std::move(endpoint));
client->OnEndpointClaimed(pair.first->first, pair.first->second);
}
} // namespace discovery
} // namespace openscreen