You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
474 lines
18 KiB
474 lines
18 KiB
// Copyright 2019 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "discovery/dnssd/impl/querier_impl.h"
|
|
|
|
#include <algorithm>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "discovery/common/reporting_client.h"
|
|
#include "discovery/dnssd/impl/conversion_layer.h"
|
|
#include "discovery/dnssd/impl/network_interface_config.h"
|
|
#include "platform/api/task_runner.h"
|
|
#include "util/osp_logging.h"
|
|
|
|
namespace openscreen {
|
|
namespace discovery {
|
|
namespace {
|
|
|
|
// Domain name paired with the service name whenever a ServiceKey is built in
// this file.
static constexpr char kLocalDomain[] = "local";
|
|
|
|
// Removes all error instances from the below records, and calls the log
|
|
// function on all errors present in |new_endpoints|. Input vectors are expected
|
|
// to be sorted in ascending order.
|
|
void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints,
|
|
std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints,
|
|
std::function<void(Error)> log) {
|
|
OSP_DCHECK(old_endpoints);
|
|
OSP_DCHECK(new_endpoints);
|
|
|
|
auto old_it = old_endpoints->begin();
|
|
auto new_it = new_endpoints->begin();
|
|
|
|
// Iterate across both vectors and log new errors in the process.
|
|
// NOTE: In sorted order, all errors will appear before all non-errors.
|
|
while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) {
|
|
ErrorOr<DnsSdInstanceEndpoint>& old_ep = *old_it;
|
|
ErrorOr<DnsSdInstanceEndpoint>& new_ep = *new_it;
|
|
|
|
if (new_ep.is_value()) {
|
|
break;
|
|
}
|
|
|
|
// If they are equal, the element is in both |old_endpoints| and
|
|
// |new_endpoints|, so skip it in both vectors.
|
|
if (old_ep == new_ep) {
|
|
old_it++;
|
|
new_it++;
|
|
continue;
|
|
}
|
|
|
|
// There's an error in |old_endpoints| not in |new_endpoints|, so skip it.
|
|
if (old_ep < new_ep) {
|
|
old_it++;
|
|
continue;
|
|
}
|
|
|
|
// There's an error in |new_endpoints| not in |old_endpoints|, so it's a new
|
|
// error from the applied changes. Log it.
|
|
log(std::move(new_ep.error()));
|
|
new_it++;
|
|
}
|
|
|
|
// Skip all remaining errors in the old vector.
|
|
for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) {
|
|
}
|
|
|
|
// Log all errors remaining in the new vector.
|
|
for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) {
|
|
log(std::move(new_it->error()));
|
|
}
|
|
|
|
// Erase errors.
|
|
old_endpoints->erase(old_endpoints->begin(), old_it);
|
|
new_endpoints->erase(new_endpoints->begin(), new_it);
|
|
}
|
|
|
|
// Returns a vector containing the value of each ErrorOr<> instance provided.
|
|
// All ErrorOr<> values are expected to be non-errors.
|
|
std::vector<DnsSdInstanceEndpoint> GetValues(
|
|
std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) {
|
|
std::vector<DnsSdInstanceEndpoint> results;
|
|
results.reserve(endpoints.size());
|
|
for (ErrorOr<DnsSdInstanceEndpoint>& endpoint : endpoints) {
|
|
results.push_back(std::move(endpoint.value()));
|
|
}
|
|
return results;
|
|
}
|
|
|
|
// Returns true when |first| and |second| refer to the same service instance,
// i.e. both are empty, or both hold endpoints whose instance, service and
// domain ids all match. Other endpoint data is not compared, so an endpoint
// and an updated version of it are considered a match.
bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
                     const absl::optional<DnsSdInstanceEndpoint>& second) {
  const bool has_first = first.has_value();
  const bool has_second = second.has_value();
  if (!has_first || !has_second) {
    // With at least one side empty, they only match if both are empty.
    return !has_first && !has_second;
  }

  const DnsSdInstanceEndpoint& lhs = *first;
  const DnsSdInstanceEndpoint& rhs = *second;

  // A querier is tied to a single network interface, so every endpoint it
  // produces is expected to report that same interface.
  OSP_DCHECK_EQ(lhs.network_interface(), rhs.network_interface());

  if (!(lhs.instance_id() == rhs.instance_id())) {
    return false;
  }
  if (!(lhs.service_id() == rhs.service_id())) {
    return false;
  }
  return lhs.domain_id() == rhs.domain_id();
}
|
|
|
|
bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
|
|
const absl::optional<DnsSdInstanceEndpoint>& second) {
|
|
return !IsEqualOrUpdate(first, second);
|
|
}
|
|
|
|
// Calculates the created, updated, and deleted elements using the provided
|
|
// sets, appending these values to the provided vectors. Each of the input
|
|
// vectors is expected to contain only elements such that
|
|
// |element|.is_error() == false. Additionally, input vectors are expected to
|
|
// be sorted in ascending order.
|
|
//
|
|
// NOTE: A lot of operations are used to do this, but each is only O(n) so the
|
|
// resulting algorithm is still fast.
|
|
void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,
|
|
std::vector<DnsSdInstanceEndpoint> new_endpoints,
|
|
std::vector<DnsSdInstanceEndpoint>* created_out,
|
|
std::vector<DnsSdInstanceEndpoint>* updated_out,
|
|
std::vector<DnsSdInstanceEndpoint>* deleted_out) {
|
|
OSP_DCHECK(created_out);
|
|
OSP_DCHECK(updated_out);
|
|
OSP_DCHECK(deleted_out);
|
|
|
|
// Use set difference with default operators to find the elements present in
|
|
// one list but not the others.
|
|
//
|
|
// NOTE: Because absl::optional<...> types are used here and below, calls to
|
|
// the ctor and dtor for empty elements are no-ops.
|
|
const int total_count = old_endpoints.size() + new_endpoints.size();
|
|
|
|
// This is the set of elements that aren't in the old endpoints, meaning the
|
|
// old endpoint either didn't exist or had different TXT / Address / etc..
|
|
std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated(
|
|
total_count);
|
|
auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(),
|
|
old_endpoints.begin(), old_endpoints.end(),
|
|
created_or_updated.begin());
|
|
created_or_updated.erase(new_end, created_or_updated.end());
|
|
|
|
// This is the set of elements that are only in the old endpoints, similar to
|
|
// the above.
|
|
std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated(
|
|
total_count);
|
|
new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(),
|
|
new_endpoints.begin(), new_endpoints.end(),
|
|
deleted_or_updated.begin());
|
|
deleted_or_updated.erase(new_end, deleted_or_updated.end());
|
|
|
|
// Next, find the elements which were updated.
|
|
const size_t max_count =
|
|
std::max(created_or_updated.size(), deleted_or_updated.size());
|
|
std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count);
|
|
new_end = std::set_intersection(
|
|
created_or_updated.begin(), created_or_updated.end(),
|
|
deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
|
|
IsNotEqualOrUpdate);
|
|
updated.erase(new_end, updated.end());
|
|
|
|
// Use the updated elements to find all created and deleted elements.
|
|
std::vector<absl::optional<DnsSdInstanceEndpoint>> created(
|
|
created_or_updated.size());
|
|
new_end = std::set_difference(
|
|
created_or_updated.begin(), created_or_updated.end(), updated.begin(),
|
|
updated.end(), created.begin(), IsNotEqualOrUpdate);
|
|
created.erase(new_end, created.end());
|
|
|
|
std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted(
|
|
deleted_or_updated.size());
|
|
new_end = std::set_difference(
|
|
deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
|
|
updated.end(), deleted.begin(), IsNotEqualOrUpdate);
|
|
deleted.erase(new_end, deleted.end());
|
|
|
|
// Return the calculated elements back to the caller in the output variables.
|
|
created_out->reserve(created.size());
|
|
for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) {
|
|
OSP_DCHECK(endpoint.has_value());
|
|
created_out->push_back(std::move(endpoint.value()));
|
|
}
|
|
|
|
updated_out->reserve(updated.size());
|
|
for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) {
|
|
OSP_DCHECK(endpoint.has_value());
|
|
updated_out->push_back(std::move(endpoint.value()));
|
|
}
|
|
|
|
deleted_out->reserve(deleted.size());
|
|
for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) {
|
|
OSP_DCHECK(endpoint.has_value());
|
|
deleted_out->push_back(std::move(endpoint.value()));
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Stores the provided collaborators and creates the DNS data graph for the
// network interface reported by |network_config|. |mdns_querier|,
// |task_runner| and |network_config| are DCHECKed to be non-null.
QuerierImpl::QuerierImpl(MdnsService* mdns_querier,
                         TaskRunner* task_runner,
                         ReportingClient* reporting_client,
                         const NetworkInterfaceConfig* network_config)
    : mdns_querier_(mdns_querier),
      task_runner_(task_runner),
      reporting_client_(reporting_client) {
  OSP_DCHECK(mdns_querier_);
  OSP_DCHECK(task_runner_);

  // |network_config| is only consulted here; no pointer to it is kept.
  OSP_DCHECK(network_config);
  const auto& interface_index = network_config->network_interface();
  graph_ = DnsDataGraph::Create(interface_index);
}
|
|
|
|
// Defaulted: this destructor performs no explicit teardown of its own.
QuerierImpl::~QuerierImpl() = default;
|
|
|
|
void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
|
|
OSP_DCHECK(callback);
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'";
|
|
|
|
// Start tracking the new callback
|
|
const ServiceKey key(service, kLocalDomain);
|
|
auto it = callback_map_.emplace(key, std::vector<Callback*>{}).first;
|
|
it->second.push_back(callback);
|
|
|
|
const DomainName domain = key.GetName();
|
|
|
|
// If the associated service isn't tracked yet, start tracking it and start
|
|
// queries for the relevant PTR records.
|
|
if (!graph_->IsTracked(domain)) {
|
|
std::function<void(const DomainName&)> mdns_query(
|
|
[this, &domain](const DomainName& changed_domain) {
|
|
OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'";
|
|
mdns_querier_->StartQuery(changed_domain, DnsType::kANY,
|
|
DnsClass::kANY, this);
|
|
});
|
|
graph_->StartTracking(domain, std::move(mdns_query));
|
|
return;
|
|
}
|
|
|
|
// Else, it's already being tracked so fire creation callbacks for any already
|
|
// found service instances.
|
|
const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints =
|
|
graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain);
|
|
for (const auto& endpoint : endpoints) {
|
|
if (endpoint.is_value()) {
|
|
callback->OnEndpointCreated(endpoint.value());
|
|
}
|
|
}
|
|
}
|
|
|
|
void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
|
|
OSP_DCHECK(callback);
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'";
|
|
|
|
ServiceKey key(service, kLocalDomain);
|
|
const auto callbacks_it = callback_map_.find(key);
|
|
if (callbacks_it == callback_map_.end()) {
|
|
return;
|
|
}
|
|
|
|
std::vector<Callback*>& callbacks = callbacks_it->second;
|
|
const auto it = std::find(callbacks.begin(), callbacks.end(), callback);
|
|
if (it == callbacks.end()) {
|
|
return;
|
|
}
|
|
|
|
callbacks.erase(it);
|
|
if (callbacks.empty()) {
|
|
callback_map_.erase(callbacks_it);
|
|
|
|
ServiceKey key(service, kLocalDomain);
|
|
DomainName domain = key.GetName();
|
|
|
|
std::function<void(const DomainName&)> stop_mdns_query(
|
|
[this](const DomainName& changed_domain) {
|
|
OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString()
|
|
<< "'";
|
|
mdns_querier_->StopQuery(changed_domain, DnsType::kANY,
|
|
DnsClass::kANY, this);
|
|
});
|
|
graph_->StopTracking(domain, std::move(stop_mdns_query));
|
|
}
|
|
}
|
|
|
|
bool QuerierImpl::IsQueryRunning(const std::string& service) const {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
const ServiceKey key(service, kLocalDomain);
|
|
return graph_->IsTracked(key.GetName());
|
|
}
|
|
|
|
void QuerierImpl::ReinitializeQueries(const std::string& service) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
OSP_DVLOG << "Re-initializing query for service '" << service << "'";
|
|
|
|
const ServiceKey key(service, kLocalDomain);
|
|
const DomainName domain = key.GetName();
|
|
|
|
std::function<void(const DomainName&)> start_callback(
|
|
[this](const DomainName& domain) {
|
|
mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this);
|
|
});
|
|
std::function<void(const DomainName&)> stop_callback(
|
|
[this](const DomainName& domain) {
|
|
mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this);
|
|
});
|
|
graph_->StopTracking(domain, std::move(stop_callback));
|
|
|
|
// Restart top-level queries.
|
|
mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
|
|
|
|
graph_->StartTracking(domain, std::move(start_callback));
|
|
}
|
|
|
|
std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged(
|
|
const MdnsRecord& record,
|
|
RecordChangedEvent event) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
OSP_DVLOG << "Record " << record.ToString()
|
|
<< " has received change of type '" << event << "'";
|
|
|
|
std::function<void(Error)> log = [this](Error error) mutable {
|
|
reporting_client_->OnRecoverableError(
|
|
Error(Error::Code::kProcessReceivedRecordFailure));
|
|
};
|
|
|
|
// Get the details to use for calling CreateEndpoints(). Special case PTR
|
|
// records to optimize performance.
|
|
const DomainName& create_endpoints_domain =
|
|
record.dns_type() != DnsType::kPTR
|
|
? record.name()
|
|
: absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
|
|
const DnsDataGraph::DomainGroup create_endpoints_group =
|
|
record.dns_type() != DnsType::kPTR
|
|
? DnsDataGraph::GetDomainGroup(record)
|
|
: DnsDataGraph::DomainGroup::kSrvAndTxt;
|
|
|
|
// Get the current set of DnsSdInstanceEndpoints prior to this change. Special
|
|
// case PTR records to avoid iterating over unrelated child domains.
|
|
std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors =
|
|
graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
|
|
|
|
// Apply the changes, creating a list of all pending changes that should be
|
|
// applied afterwards.
|
|
ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error =
|
|
ApplyRecordChanges(record, event);
|
|
if (pending_changes_or_error.is_error()) {
|
|
OSP_DVLOG << "Failed to apply changes for " << record.dns_type()
|
|
<< " record change of type " << event << " with error "
|
|
<< pending_changes_or_error.error();
|
|
log(std::move(pending_changes_or_error.error()));
|
|
return {};
|
|
}
|
|
std::vector<PendingQueryChange>& pending_changes =
|
|
pending_changes_or_error.value();
|
|
|
|
// Get the new set of DnsSdInstanceEndpoints following this change.
|
|
std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors =
|
|
graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
|
|
|
|
// Return early if the resulting sets are equal. This will frequently be the
|
|
// case, especially when both sets are empty.
|
|
std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end());
|
|
std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end());
|
|
if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() &&
|
|
std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(),
|
|
new_endpoints_or_errors.begin())) {
|
|
return pending_changes;
|
|
}
|
|
|
|
// Log all errors and erase them.
|
|
ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors,
|
|
std::move(log));
|
|
const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size();
|
|
const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size();
|
|
std::vector<DnsSdInstanceEndpoint> old_endpoints =
|
|
GetValues(std::move(old_endpoints_or_errors));
|
|
std::vector<DnsSdInstanceEndpoint> new_endpoints =
|
|
GetValues(std::move(new_endpoints_or_errors));
|
|
OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count);
|
|
OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count);
|
|
|
|
// Calculate the changes and call callbacks.
|
|
//
|
|
// NOTE: As the input sets are expected to be small, the generated sets will
|
|
// also be small.
|
|
std::vector<DnsSdInstanceEndpoint> created;
|
|
std::vector<DnsSdInstanceEndpoint> updated;
|
|
std::vector<DnsSdInstanceEndpoint> deleted;
|
|
CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints),
|
|
&created, &updated, &deleted);
|
|
|
|
InvokeChangeCallbacks(std::move(created), std::move(updated),
|
|
std::move(deleted));
|
|
return pending_changes;
|
|
}
|
|
|
|
void QuerierImpl::InvokeChangeCallbacks(
|
|
std::vector<DnsSdInstanceEndpoint> created,
|
|
std::vector<DnsSdInstanceEndpoint> updated,
|
|
std::vector<DnsSdInstanceEndpoint> deleted) {
|
|
// Find an endpoint and use it to create the key, or return if there is none.
|
|
DnsSdInstanceEndpoint* some_endpoint;
|
|
if (!created.empty()) {
|
|
some_endpoint = &created.front();
|
|
} else if (!updated.empty()) {
|
|
some_endpoint = &updated.front();
|
|
} else if (!deleted.empty()) {
|
|
some_endpoint = &deleted.front();
|
|
} else {
|
|
return;
|
|
}
|
|
ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id());
|
|
|
|
// Find all callbacks.
|
|
auto it = callback_map_.find(key);
|
|
if (it == callback_map_.end()) {
|
|
return;
|
|
}
|
|
|
|
// Call relevant callbacks.
|
|
std::vector<Callback*>& callbacks = it->second;
|
|
for (Callback* callback : callbacks) {
|
|
for (const DnsSdInstanceEndpoint& endpoint : created) {
|
|
callback->OnEndpointCreated(endpoint);
|
|
}
|
|
for (const DnsSdInstanceEndpoint& endpoint : updated) {
|
|
callback->OnEndpointUpdated(endpoint);
|
|
}
|
|
for (const DnsSdInstanceEndpoint& endpoint : deleted) {
|
|
callback->OnEndpointDeleted(endpoint);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Forwards the record change to the data graph and collects the query changes
// it requests: a start-query entry for each domain passed to the creation
// callback and a stop-query entry for each domain passed to the deletion
// callback. Returns the collected changes, or the graph's error if the change
// could not be applied.
ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges(
    const MdnsRecord& record,
    RecordChangedEvent event) {
  std::vector<PendingQueryChange> pending_changes;

  // Both callbacks only append to |pending_changes|, which outlives the call
  // to ApplyDataRecordChange() below.
  std::function<void(DomainName)> on_domain_created(
      [this, &pending_changes](DomainName created_domain) {
        pending_changes.push_back({std::move(created_domain), DnsType::kANY,
                                   DnsClass::kANY, this,
                                   PendingQueryChange::kStartQuery});
      });
  std::function<void(DomainName)> on_domain_deleted(
      [this, &pending_changes](DomainName deleted_domain) {
        pending_changes.push_back({std::move(deleted_domain), DnsType::kANY,
                                   DnsClass::kANY, this,
                                   PendingQueryChange::kStopQuery});
      });

  Error result =
      graph_->ApplyDataRecordChange(record, event, std::move(on_domain_created),
                                    std::move(on_domain_deleted));
  if (result.ok()) {
    return pending_changes;
  }
  return result;
}
|
|
|
|
} // namespace discovery
|
|
} // namespace openscreen
|