// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "discovery/dnssd/impl/querier_impl.h" #include #include #include #include #include "discovery/common/reporting_client.h" #include "discovery/dnssd/impl/conversion_layer.h" #include "discovery/dnssd/impl/network_interface_config.h" #include "platform/api/task_runner.h" #include "util/osp_logging.h" namespace openscreen { namespace discovery { namespace { static constexpr char kLocalDomain[] = "local"; // Removes all error instances from the below records, and calls the log // function on all errors present in |new_endpoints|. Input vectors are expected // to be sorted in ascending order. void ProcessErrors(std::vector>* old_endpoints, std::vector>* new_endpoints, std::function log) { OSP_DCHECK(old_endpoints); OSP_DCHECK(new_endpoints); auto old_it = old_endpoints->begin(); auto new_it = new_endpoints->begin(); // Iterate across both vectors and log new errors in the process. // NOTE: In sorted order, all errors will appear before all non-errors. while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) { ErrorOr& old_ep = *old_it; ErrorOr& new_ep = *new_it; if (new_ep.is_value()) { break; } // If they are equal, the element is in both |old_endpoints| and // |new_endpoints|, so skip it in both vectors. if (old_ep == new_ep) { old_it++; new_it++; continue; } // There's an error in |old_endpoints| not in |new_endpoints|, so skip it. if (old_ep < new_ep) { old_it++; continue; } // There's an error in |new_endpoints| not in |old_endpoints|, so it's a new // error from the applied changes. Log it. log(std::move(new_ep.error())); new_it++; } // Skip all remaining errors in the old vector. for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) { } // Log all errors remaining in the new vector. for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) { log(std::move(new_it->error())); } // Erase errors. old_endpoints->erase(old_endpoints->begin(), old_it); new_endpoints->erase(new_endpoints->begin(), new_it); } // Returns a vector containing the value of each ErrorOr<> instance provided. // All ErrorOr<> values are expected to be non-errors. std::vector GetValues( std::vector> endpoints) { std::vector results; results.reserve(endpoints.size()); for (ErrorOr& endpoint : endpoints) { results.push_back(std::move(endpoint.value())); } return results; } bool IsEqualOrUpdate(const absl::optional& first, const absl::optional& second) { if (!first.has_value() || !second.has_value()) { return !first.has_value() && !second.has_value(); } // In the remaining case, both |first| and |second| must be values. const DnsSdInstanceEndpoint& a = first.value(); const DnsSdInstanceEndpoint& b = second.value(); // All endpoints from this querier should have the same network interface // because the querier is only associated with a single network interface. OSP_DCHECK_EQ(a.network_interface(), b.network_interface()); // Function returns true if first < second. return a.instance_id() == b.instance_id() && a.service_id() == b.service_id() && a.domain_id() == b.domain_id(); } bool IsNotEqualOrUpdate(const absl::optional& first, const absl::optional& second) { return !IsEqualOrUpdate(first, second); } // Calculates the created, updated, and deleted elements using the provided // sets, appending these values to the provided vectors. Each of the input // vectors is expected to contain only elements such that // |element|.is_error() == false. Additionally, input vectors are expected to // be sorted in ascending order. // // NOTE: A lot of operations are used to do this, but each is only O(n) so the // resulting algorithm is still fast. void CalculateChangeSets(std::vector old_endpoints, std::vector new_endpoints, std::vector* created_out, std::vector* updated_out, std::vector* deleted_out) { OSP_DCHECK(created_out); OSP_DCHECK(updated_out); OSP_DCHECK(deleted_out); // Use set difference with default operators to find the elements present in // one list but not the others. // // NOTE: Because absl::optional<...> types are used here and below, calls to // the ctor and dtor for empty elements are no-ops. const int total_count = old_endpoints.size() + new_endpoints.size(); // This is the set of elements that aren't in the old endpoints, meaning the // old endpoint either didn't exist or had different TXT / Address / etc.. std::vector> created_or_updated( total_count); auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(), old_endpoints.begin(), old_endpoints.end(), created_or_updated.begin()); created_or_updated.erase(new_end, created_or_updated.end()); // This is the set of elements that are only in the old endpoints, similar to // the above. std::vector> deleted_or_updated( total_count); new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(), new_endpoints.begin(), new_endpoints.end(), deleted_or_updated.begin()); deleted_or_updated.erase(new_end, deleted_or_updated.end()); // Next, find the elements which were updated. const size_t max_count = std::max(created_or_updated.size(), deleted_or_updated.size()); std::vector> updated(max_count); new_end = std::set_intersection( created_or_updated.begin(), created_or_updated.end(), deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(), IsNotEqualOrUpdate); updated.erase(new_end, updated.end()); // Use the updated elements to find all created and deleted elements. std::vector> created( created_or_updated.size()); new_end = std::set_difference( created_or_updated.begin(), created_or_updated.end(), updated.begin(), updated.end(), created.begin(), IsNotEqualOrUpdate); created.erase(new_end, created.end()); std::vector> deleted( deleted_or_updated.size()); new_end = std::set_difference( deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(), updated.end(), deleted.begin(), IsNotEqualOrUpdate); deleted.erase(new_end, deleted.end()); // Return the calculated elements back to the caller in the output variables. created_out->reserve(created.size()); for (absl::optional& endpoint : created) { OSP_DCHECK(endpoint.has_value()); created_out->push_back(std::move(endpoint.value())); } updated_out->reserve(updated.size()); for (absl::optional& endpoint : updated) { OSP_DCHECK(endpoint.has_value()); updated_out->push_back(std::move(endpoint.value())); } deleted_out->reserve(deleted.size()); for (absl::optional& endpoint : deleted) { OSP_DCHECK(endpoint.has_value()); deleted_out->push_back(std::move(endpoint.value())); } } } // namespace QuerierImpl::QuerierImpl(MdnsService* mdns_querier, TaskRunner* task_runner, ReportingClient* reporting_client, const NetworkInterfaceConfig* network_config) : mdns_querier_(mdns_querier), task_runner_(task_runner), reporting_client_(reporting_client) { OSP_DCHECK(mdns_querier_); OSP_DCHECK(task_runner_); OSP_DCHECK(network_config); graph_ = DnsDataGraph::Create(network_config->network_interface()); } QuerierImpl::~QuerierImpl() = default; void QuerierImpl::StartQuery(const std::string& service, Callback* callback) { OSP_DCHECK(callback); OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'"; // Start tracking the new callback const ServiceKey key(service, kLocalDomain); auto it = callback_map_.emplace(key, std::vector{}).first; it->second.push_back(callback); const DomainName domain = key.GetName(); // If the associated service isn't tracked yet, start tracking it and start // queries for the relevant PTR records. if (!graph_->IsTracked(domain)) { std::function mdns_query( [this, &domain](const DomainName& changed_domain) { OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'"; mdns_querier_->StartQuery(changed_domain, DnsType::kANY, DnsClass::kANY, this); }); graph_->StartTracking(domain, std::move(mdns_query)); return; } // Else, it's already being tracked so fire creation callbacks for any already // found service instances. const std::vector> endpoints = graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain); for (const auto& endpoint : endpoints) { if (endpoint.is_value()) { callback->OnEndpointCreated(endpoint.value()); } } } void QuerierImpl::StopQuery(const std::string& service, Callback* callback) { OSP_DCHECK(callback); OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'"; ServiceKey key(service, kLocalDomain); const auto callbacks_it = callback_map_.find(key); if (callbacks_it == callback_map_.end()) { return; } std::vector& callbacks = callbacks_it->second; const auto it = std::find(callbacks.begin(), callbacks.end(), callback); if (it == callbacks.end()) { return; } callbacks.erase(it); if (callbacks.empty()) { callback_map_.erase(callbacks_it); ServiceKey key(service, kLocalDomain); DomainName domain = key.GetName(); std::function stop_mdns_query( [this](const DomainName& changed_domain) { OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString() << "'"; mdns_querier_->StopQuery(changed_domain, DnsType::kANY, DnsClass::kANY, this); }); graph_->StopTracking(domain, std::move(stop_mdns_query)); } } bool QuerierImpl::IsQueryRunning(const std::string& service) const { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); const ServiceKey key(service, kLocalDomain); return graph_->IsTracked(key.GetName()); } void QuerierImpl::ReinitializeQueries(const std::string& service) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DVLOG << "Re-initializing query for service '" << service << "'"; const ServiceKey key(service, kLocalDomain); const DomainName domain = key.GetName(); std::function start_callback( [this](const DomainName& domain) { mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this); }); std::function stop_callback( [this](const DomainName& domain) { mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this); }); graph_->StopTracking(domain, std::move(stop_callback)); // Restart top-level queries. mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name); graph_->StartTracking(domain, std::move(start_callback)); } std::vector QuerierImpl::OnRecordChanged( const MdnsRecord& record, RecordChangedEvent event) { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); OSP_DVLOG << "Record " << record.ToString() << " has received change of type '" << event << "'"; std::function log = [this](Error error) mutable { reporting_client_->OnRecoverableError( Error(Error::Code::kProcessReceivedRecordFailure)); }; // Get the details to use for calling CreateEndpoints(). Special case PTR // records to optimize performance. const DomainName& create_endpoints_domain = record.dns_type() != DnsType::kPTR ? record.name() : absl::get(record.rdata()).ptr_domain(); const DnsDataGraph::DomainGroup create_endpoints_group = record.dns_type() != DnsType::kPTR ? DnsDataGraph::GetDomainGroup(record) : DnsDataGraph::DomainGroup::kSrvAndTxt; // Get the current set of DnsSdInstanceEndpoints prior to this change. Special // case PTR records to avoid iterating over unrelated child domains. std::vector> old_endpoints_or_errors = graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain); // Apply the changes, creating a list of all pending changes that should be // applied afterwards. ErrorOr> pending_changes_or_error = ApplyRecordChanges(record, event); if (pending_changes_or_error.is_error()) { OSP_DVLOG << "Failed to apply changes for " << record.dns_type() << " record change of type " << event << " with error " << pending_changes_or_error.error(); log(std::move(pending_changes_or_error.error())); return {}; } std::vector& pending_changes = pending_changes_or_error.value(); // Get the new set of DnsSdInstanceEndpoints following this change. std::vector> new_endpoints_or_errors = graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain); // Return early if the resulting sets are equal. This will frequently be the // case, especially when both sets are empty. std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end()); std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end()); if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() && std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(), new_endpoints_or_errors.begin())) { return pending_changes; } // Log all errors and erase them. ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors, std::move(log)); const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size(); const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size(); std::vector old_endpoints = GetValues(std::move(old_endpoints_or_errors)); std::vector new_endpoints = GetValues(std::move(new_endpoints_or_errors)); OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count); OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count); // Calculate the changes and call callbacks. // // NOTE: As the input sets are expected to be small, the generated sets will // also be small. std::vector created; std::vector updated; std::vector deleted; CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints), &created, &updated, &deleted); InvokeChangeCallbacks(std::move(created), std::move(updated), std::move(deleted)); return pending_changes; } void QuerierImpl::InvokeChangeCallbacks( std::vector created, std::vector updated, std::vector deleted) { // Find an endpoint and use it to create the key, or return if there is none. DnsSdInstanceEndpoint* some_endpoint; if (!created.empty()) { some_endpoint = &created.front(); } else if (!updated.empty()) { some_endpoint = &updated.front(); } else if (!deleted.empty()) { some_endpoint = &deleted.front(); } else { return; } ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id()); // Find all callbacks. auto it = callback_map_.find(key); if (it == callback_map_.end()) { return; } // Call relevant callbacks. std::vector& callbacks = it->second; for (Callback* callback : callbacks) { for (const DnsSdInstanceEndpoint& endpoint : created) { callback->OnEndpointCreated(endpoint); } for (const DnsSdInstanceEndpoint& endpoint : updated) { callback->OnEndpointUpdated(endpoint); } for (const DnsSdInstanceEndpoint& endpoint : deleted) { callback->OnEndpointDeleted(endpoint); } } } ErrorOr> QuerierImpl::ApplyRecordChanges( const MdnsRecord& record, RecordChangedEvent event) { std::vector pending_changes; std::function creation_callback( [this, &pending_changes](DomainName domain) mutable { pending_changes.push_back({std::move(domain), DnsType::kANY, DnsClass::kANY, this, PendingQueryChange::kStartQuery}); }); std::function deletion_callback( [this, &pending_changes](DomainName domain) mutable { pending_changes.push_back({std::move(domain), DnsType::kANY, DnsClass::kANY, this, PendingQueryChange::kStopQuery}); }); Error result = graph_->ApplyDataRecordChange(record, event, std::move(creation_callback), std::move(deletion_callback)); if (!result.ok()) { return result; } return pending_changes; } } // namespace discovery } // namespace openscreen