You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
878 lines
31 KiB
878 lines
31 KiB
// Copyright 2019 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "discovery/mdns/mdns_querier.h"
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <bitset>
|
|
#include <memory>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "discovery/common/config.h"
|
|
#include "discovery/common/reporting_client.h"
|
|
#include "discovery/mdns/mdns_random.h"
|
|
#include "discovery/mdns/mdns_receiver.h"
|
|
#include "discovery/mdns/mdns_sender.h"
|
|
#include "discovery/mdns/public/mdns_constants.h"
|
|
|
|
namespace openscreen {
|
|
namespace discovery {
|
|
namespace {
|
|
|
|
// Concrete record types substituted for DnsType::kANY when an NSEC record that
// lists kANY is processed (see ProcessRecord()).
constexpr std::array<DnsType, 5> kTranslatedNsecAnyQueryTypes = {
    DnsType::kA,    DnsType::kPTR, DnsType::kTXT,
    DnsType::kAAAA, DnsType::kSRV,
};
|
|
|
|
bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
|
|
if (record.dns_type() != DnsType::kNSEC) {
|
|
return false;
|
|
}
|
|
|
|
const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
|
|
|
|
// RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
|
|
// record to indicate this is an mDNS NSEC record rather than a traditional
|
|
// DNS NSEC record.
|
|
if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
|
|
nsec.types().end()) {
|
|
return false;
|
|
}
|
|
|
|
return std::find_if(nsec.types().begin(), nsec.types().end(),
|
|
[type](DnsType stored_type) {
|
|
return stored_type == type ||
|
|
stored_type == DnsType::kANY;
|
|
}) != nsec.types().end();
|
|
}
|
|
|
|
struct HashDnsType {
|
|
inline size_t operator()(DnsType type) const {
|
|
return static_cast<size_t>(type);
|
|
}
|
|
};
|
|
|
|
// Helper used for sorting MDNS records. This function guarantees the following:
|
|
// - All MdnsRecords with the same name appear adjacent to each-other.
|
|
// - An NSEC record with a given name appears before all other records with the
|
|
// same name.
|
|
bool CompareRecordByNameAndType(const MdnsRecord& first,
|
|
const MdnsRecord& second) {
|
|
if (first.name() != second.name()) {
|
|
return first.name() < second.name();
|
|
}
|
|
|
|
if ((first.dns_type() == DnsType::kNSEC) !=
|
|
(second.dns_type() == DnsType::kNSEC)) {
|
|
return first.dns_type() == DnsType::kNSEC;
|
|
}
|
|
|
|
return first < second;
|
|
}
|
|
|
|
class DnsTypeBitSet {
|
|
public:
|
|
// Returns whether any types are currently stored in this data structure.
|
|
bool IsEmpty() { return !elements_.any(); }
|
|
|
|
// Attempts to insert the given type into this data structure. Returns
|
|
// true iff the type was not already present.
|
|
bool Insert(DnsType type) {
|
|
uint16_t bit = (type == DnsType::kANY) ? 0 : static_cast<uint16_t>(type);
|
|
bool was_set = elements_.test(bit);
|
|
elements_.set(bit);
|
|
return !was_set;
|
|
}
|
|
|
|
// Iterates over all members of the provided container, inserting each
|
|
// DnsType contained within to this instance. Returns true iff any element
|
|
// inserted was not already present in this instance.
|
|
template <typename Container>
|
|
bool Insert(const Container& container) {
|
|
bool has_element_been_inserted = false;
|
|
for (DnsType type : container) {
|
|
has_element_been_inserted |= Insert(type);
|
|
}
|
|
return has_element_been_inserted;
|
|
}
|
|
|
|
// Attempts to remove the given type from this data structure. Returns true
|
|
// iff the type was present prior to this call.
|
|
bool Remove(DnsType type) {
|
|
if (IsEmpty()) {
|
|
return false;
|
|
} else if (type == DnsType::kANY) {
|
|
elements_.reset();
|
|
return true;
|
|
}
|
|
|
|
uint16_t bit = static_cast<uint16_t>(type);
|
|
bool was_set = elements_.test(bit);
|
|
elements_.reset(bit);
|
|
return was_set;
|
|
}
|
|
|
|
// Returns the DnsTypes currently stored in this data structure.
|
|
std::vector<DnsType> GetTypes() const {
|
|
if (elements_.test(0)) {
|
|
return {DnsType::kANY};
|
|
}
|
|
|
|
std::vector<DnsType> types;
|
|
for (DnsType type : kSupportedDnsTypes) {
|
|
if (type == DnsType::kANY) {
|
|
continue;
|
|
}
|
|
|
|
uint16_t cast_int = static_cast<uint16_t>(type);
|
|
if (elements_.test(cast_int)) {
|
|
types.push_back(type);
|
|
}
|
|
}
|
|
return types;
|
|
}
|
|
|
|
private:
|
|
std::bitset<64> elements_;
|
|
};
|
|
|
|
// Modifies |records| such that no NSEC record signifies the nonexistance of a
|
|
// record which is also present in the same message. Order of the input vector
|
|
// is NOT preserved.
|
|
// NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
|
|
// be modified.
|
|
// TODO(b/170353378): Break this logic into a separate processing module between
|
|
// the MdnsReader and the MdnsQuerier.
|
|
void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
|
|
// Sort the records so NSEC records are first so that only one iteration
|
|
// through all records is needed.
|
|
std::sort(records->begin(), records->end(), CompareRecordByNameAndType);
|
|
|
|
// The set of NSEC records that need to be removed from |records|. This can't
|
|
// be done as part of the below loop because it would invalidate the iterator
|
|
// that's still being used.
|
|
std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;
|
|
|
|
// Process all elements.
|
|
for (auto it = records->begin(); it != records->end();) {
|
|
if (it->dns_type() != DnsType::kNSEC) {
|
|
it++;
|
|
continue;
|
|
}
|
|
|
|
// Track whether the current NSEC record in the input vector has been
|
|
// modified by some step of this algorithm, be that merging with another
|
|
// record, removing a DnsType, or any other modification.
|
|
bool has_changed = false;
|
|
|
|
// The types for the new record to create, if |has_changed|.
|
|
const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
|
|
DnsTypeBitSet types;
|
|
for (DnsType type : nsec_rdata.types()) {
|
|
types.Insert(type);
|
|
}
|
|
auto nsec = it;
|
|
it++;
|
|
|
|
// Combine multiple NSECs to simplify the following code. This probably
|
|
// won't happen, but the RFC doesn't exclude the possibility, so account for
|
|
// it. Define the TTL of this new NSEC record created by this merge process
|
|
// to be the minimum of all merged NSEC records.
|
|
std::chrono::seconds new_ttl = nsec->ttl();
|
|
while (it != records->end() && it->name() == nsec->name() &&
|
|
it->dns_type() == DnsType::kNSEC) {
|
|
has_changed |=
|
|
types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
|
|
new_ttl = std::min(new_ttl, it->ttl());
|
|
it = records->erase(it);
|
|
}
|
|
|
|
// Remove any types associated with a known record type.
|
|
for (; it != records->end() && it->name() == nsec->name(); it++) {
|
|
OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
|
|
has_changed |= types.Remove(it->dns_type());
|
|
}
|
|
|
|
// Modify the stored NSEC record, if needed.
|
|
if (has_changed && types.IsEmpty()) {
|
|
nsecs_to_delete.push_back(nsec);
|
|
} else if (has_changed) {
|
|
NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
|
|
types.GetTypes());
|
|
*nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
|
|
nsec->record_type(), new_ttl, std::move(new_rdata));
|
|
}
|
|
}
|
|
|
|
// Erase invalid NSEC records. Go backwards to avoid invalidating the
|
|
// remaining iterators.
|
|
for (auto erase_it = nsecs_to_delete.rbegin();
|
|
erase_it != nsecs_to_delete.rend(); erase_it++) {
|
|
records->erase(*erase_it);
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Creates an empty record cache. Every tracker it creates shares the given
// dependencies, and at most |config.querier_max_records_cached| trackers are
// kept at once.
MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
    MdnsQuerier* querier,
    MdnsSender* sender,
    MdnsRandom* random_delay,
    TaskRunner* task_runner,
    ClockNowFunctionPtr now_function,
    ReportingClient* reporting_client,
    const Config& config)
    : querier_(querier),
      sender_(sender),
      random_delay_(random_delay),
      task_runner_(task_runner),
      now_function_(now_function),
      reporting_client_(reporting_client),
      config_(config) {
  // |querier_| is dereferenced by every tracker's expiration callback and
  // |now_function_| is handed to every tracker, so both are validated just
  // like the remaining dependencies.
  OSP_DCHECK(querier_);
  OSP_DCHECK(sender_);
  OSP_DCHECK(random_delay_);
  OSP_DCHECK(task_runner_);
  OSP_DCHECK(now_function_);
  OSP_DCHECK(reporting_client_);
  OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
}
|
|
|
|
// Returns every tracker stored under |name|, regardless of type or class.
std::vector<std::reference_wrapper<const MdnsRecordTracker>>
MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
  const DnsType any_type = DnsType::kANY;
  const DnsClass any_class = DnsClass::kANY;
  return Find(name, any_type, any_class);
}
|
|
|
|
// Returns the trackers stored under |name| whose type and class match
// |dns_type| and |dns_class|. kANY in either argument matches every value.
std::vector<std::reference_wrapper<const MdnsRecordTracker>>
MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
                                         DnsType dns_type,
                                         DnsClass dns_class) {
  const bool any_type = dns_type == DnsType::kANY;
  const bool any_class = dns_class == DnsClass::kANY;

  std::vector<RecordTrackerConstRef> matches;
  const auto range = records_.equal_range(name);
  for (auto entry = range.first; entry != range.second; ++entry) {
    const MdnsRecordTracker& candidate = *entry->second;
    const bool type_ok = any_type || candidate.dns_type() == dns_type;
    const bool class_ok = any_class || candidate.dns_class() == dns_class;
    if (type_ok && class_ok) {
      matches.emplace_back(candidate);
    }
  }
  return matches;
}
|
|
|
|
int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
|
|
TrackerApplicableCheck check) {
|
|
auto pair = records_.equal_range(domain);
|
|
int count = 0;
|
|
for (RecordMap::iterator it = pair.first; it != pair.second;) {
|
|
if (check(*it->second)) {
|
|
lru_order_.erase(it->second);
|
|
it = records_.erase(it);
|
|
count++;
|
|
} else {
|
|
it++;
|
|
}
|
|
}
|
|
|
|
return count;
|
|
}
|
|
|
|
int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
|
|
const DomainName& domain,
|
|
TrackerApplicableCheck check) {
|
|
auto pair = records_.equal_range(domain);
|
|
int count = 0;
|
|
for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
|
|
if (check(*it->second)) {
|
|
MoveToEnd(it);
|
|
it->second->ExpireSoon();
|
|
count++;
|
|
}
|
|
}
|
|
|
|
return count;
|
|
}
|
|
|
|
int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
|
|
TrackerApplicableCheck check) {
|
|
return Update(record, check, [](const MdnsRecordTracker& t) {});
|
|
}
|
|
|
|
int MdnsQuerier::RecordTrackerLruCache::Update(
|
|
const MdnsRecord& record,
|
|
TrackerApplicableCheck check,
|
|
TrackerChangeCallback on_rdata_update) {
|
|
auto pair = records_.equal_range(record.name());
|
|
int count = 0;
|
|
for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
|
|
if (check(*it->second)) {
|
|
auto result = it->second->Update(record);
|
|
|
|
if (result.is_error()) {
|
|
reporting_client_->OnRecoverableError(
|
|
Error(Error::Code::kUpdateReceivedRecordFailure,
|
|
result.error().ToString()));
|
|
continue;
|
|
}
|
|
|
|
count++;
|
|
if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
|
|
it->second->ExpireSoon();
|
|
MoveToEnd(it);
|
|
} else {
|
|
MoveToBeginning(it);
|
|
if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
|
|
on_rdata_update(*it->second);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return count;
|
|
}
|
|
|
|
// Starts tracking |record| under |dns_type| as the most recently used entry.
// When the cache is at capacity, the least recently used trackers are expired
// first. Returns the new tracker.
const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
    MdnsRecord record,
    DnsType dns_type) {
  const size_t capacity =
      static_cast<size_t>(config_.querier_max_records_cached);
  while (lru_order_.size() >= capacity) {
    // Expiring the last LRU entry erases it from the cache, so each iteration
    // shrinks |lru_order_| by one.
    OSP_DVLOG << "Maximum cacheable record count exceeded ("
              << config_.querier_max_records_cached << ")";
    lru_order_.back().ExpireNow();
  }

  // Copy the name out before |record| is moved into the tracker.
  DomainName name = record.name();
  lru_order_.emplace_front(
      std::move(record), dns_type, sender_, task_runner_, now_function_,
      random_delay_,
      [this](const MdnsRecordTracker* tracker, const MdnsRecord& expired) {
        querier_->OnRecordExpired(tracker, expired);
      });
  records_.emplace(std::move(name), lru_order_.begin());

  return lru_order_.front();
}
|
|
|
|
void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
|
|
MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
|
|
lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
|
|
it->second = lru_order_.begin();
|
|
}
|
|
|
|
void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
|
|
MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
|
|
lru_order_.splice(lru_order_.end(), lru_order_, it->second);
|
|
it->second = --lru_order_.end();
|
|
}
|
|
|
|
// Creates a querier and registers it with |receiver| for response messages.
// All pointer arguments are required.
MdnsQuerier::MdnsQuerier(MdnsSender* sender,
                         MdnsReceiver* receiver,
                         TaskRunner* task_runner,
                         ClockNowFunctionPtr now_function,
                         MdnsRandom* random_delay,
                         ReportingClient* reporting_client,
                         Config config)
    : sender_(sender),
      receiver_(receiver),
      task_runner_(task_runner),
      now_function_(now_function),
      random_delay_(random_delay),
      reporting_client_(reporting_client),
      config_(std::move(config)),
      // The record cache shares this querier's dependencies and configuration
      // and calls back into this querier when a tracked record expires.
      records_(this,
               sender_,
               random_delay_,
               task_runner_,
               now_function_,
               reporting_client_,
               config_) {
  OSP_DCHECK(receiver_);
  OSP_DCHECK(sender_);
  OSP_DCHECK(task_runner_);
  OSP_DCHECK(random_delay_);
  OSP_DCHECK(now_function_);
  OSP_DCHECK(reporting_client_);

  // Undone in the destructor.
  receiver_->AddResponseCallback(this);
}
|
|
|
|
// Unregisters from |receiver_|, undoing the registration made in the
// constructor.
MdnsQuerier::~MdnsQuerier() {
  MdnsReceiver* const receiver = receiver_;
  receiver->RemoveResponseCallback(this);
}
|
|
|
|
// NOTE: The code below mostly uses range-based loops instead of std::find_if
// for better readability, brevity and homogeneity. Using std::find_if results
// in a few more lines of code, and readability suffers from the extra lambdas.
|
|
|
|
void MdnsQuerier::StartQuery(const DomainName& name,
|
|
DnsType dns_type,
|
|
DnsClass dns_class,
|
|
MdnsRecordChangedCallback* callback) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
OSP_DCHECK(callback);
|
|
OSP_DCHECK(CanBeQueried(dns_type));
|
|
|
|
// Add a new callback if haven't seen it before
|
|
auto callbacks_it = callbacks_.equal_range(name);
|
|
for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
|
|
const CallbackInfo& callback_info = entry->second;
|
|
if (dns_type == callback_info.dns_type &&
|
|
dns_class == callback_info.dns_class &&
|
|
callback == callback_info.callback) {
|
|
// Already have this callback
|
|
return;
|
|
}
|
|
}
|
|
callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
|
|
|
|
// Notify the new callback with previously cached records.
|
|
// NOTE: In the future, could allow callers to fetch cached records after
|
|
// adding a callback, for example to prime the UI.
|
|
std::vector<PendingQueryChange> pending_changes;
|
|
const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
|
|
records_.Find(name, dns_type, dns_class);
|
|
for (const MdnsRecordTracker& tracker : trackers) {
|
|
if (!tracker.is_negative_response()) {
|
|
MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
|
|
tracker.record_type(), tracker.ttl(),
|
|
tracker.rdata());
|
|
std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
|
|
std::move(stored_record), RecordChangedEvent::kCreated);
|
|
pending_changes.insert(pending_changes.end(), new_changes.begin(),
|
|
new_changes.end());
|
|
}
|
|
}
|
|
|
|
// Add a new question if haven't seen it before
|
|
auto questions_it = questions_.equal_range(name);
|
|
const bool is_question_already_tracked =
|
|
std::find_if(questions_it.first, questions_it.second,
|
|
[dns_type, dns_class](const auto& entry) {
|
|
const MdnsQuestion& tracked_question =
|
|
entry.second->question();
|
|
return dns_type == tracked_question.dns_type() &&
|
|
dns_class == tracked_question.dns_class();
|
|
}) != questions_it.second;
|
|
if (!is_question_already_tracked) {
|
|
AddQuestion(
|
|
MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
|
|
}
|
|
|
|
// Apply any pending changes from the OnRecordChanged() callbacks.
|
|
ApplyPendingChanges(std::move(pending_changes));
|
|
}
|
|
|
|
void MdnsQuerier::StopQuery(const DomainName& name,
|
|
DnsType dns_type,
|
|
DnsClass dns_class,
|
|
MdnsRecordChangedCallback* callback) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
OSP_DCHECK(callback);
|
|
|
|
if (!CanBeQueried(dns_type)) {
|
|
return;
|
|
}
|
|
|
|
// Find and remove the callback.
|
|
int callbacks_for_key = 0;
|
|
auto callbacks_it = callbacks_.equal_range(name);
|
|
for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
|
|
const CallbackInfo& callback_info = entry->second;
|
|
if (dns_type == callback_info.dns_type &&
|
|
dns_class == callback_info.dns_class) {
|
|
if (callback == callback_info.callback) {
|
|
entry = callbacks_.erase(entry);
|
|
} else {
|
|
++callbacks_for_key;
|
|
++entry;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Exit if there are still callbacks registered for DomainName + DnsType +
|
|
// DnsClass
|
|
if (callbacks_for_key > 0) {
|
|
return;
|
|
}
|
|
|
|
// Find and delete a question that does not have any associated callbacks
|
|
auto questions_it = questions_.equal_range(name);
|
|
for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
|
|
const MdnsQuestion& tracked_question = entry->second->question();
|
|
if (dns_type == tracked_question.dns_type() &&
|
|
dns_class == tracked_question.dns_class()) {
|
|
questions_.erase(entry);
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
// Get the ongoing queries and their callbacks.
|
|
std::vector<CallbackInfo> callbacks;
|
|
auto its = callbacks_.equal_range(name);
|
|
for (auto it = its.first; it != its.second; it++) {
|
|
callbacks.push_back(std::move(it->second));
|
|
}
|
|
callbacks_.erase(name);
|
|
|
|
// Remove all known questions and answers.
|
|
questions_.erase(name);
|
|
records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
|
|
|
|
// Restart the queries.
|
|
for (const auto& cb : callbacks) {
|
|
StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
OSP_DCHECK(message.type() == MessageType::Response);
|
|
|
|
OSP_DVLOG << "Received mDNS Response message with "
|
|
<< message.answers().size() << " answers and "
|
|
<< message.additional_records().size()
|
|
<< " additional records. Processing...";
|
|
|
|
std::vector<MdnsRecord> records_to_process;
|
|
|
|
// Add any records that are relevant for this querier.
|
|
bool found_relevant_records = false;
|
|
for (const MdnsRecord& record : message.answers()) {
|
|
if (ShouldAnswerRecordBeProcessed(record)) {
|
|
records_to_process.push_back(record);
|
|
found_relevant_records = true;
|
|
}
|
|
}
|
|
|
|
// If any of the message's answers are relevant, add all additional records.
|
|
// Else, since the message has already been received and parsed, use any
|
|
// individual records relevant to this querier to update the cache.
|
|
for (const MdnsRecord& record : message.additional_records()) {
|
|
if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
|
|
records_to_process.push_back(record);
|
|
}
|
|
}
|
|
|
|
// Drop NSEC records associated with a non-NSEC record of the same type.
|
|
RemoveInvalidNsecFlags(&records_to_process);
|
|
|
|
// Process all remaining records.
|
|
for (const MdnsRecord& record_to_process : records_to_process) {
|
|
ProcessRecord(record_to_process);
|
|
}
|
|
|
|
OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
|
|
<< " records accepted)!";
|
|
|
|
// TODO(crbug.com/openscreen/83): Check authority records.
|
|
}
|
|
|
|
bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
|
|
// First, accept the record if it's associated with an ongoing question.
|
|
const auto questions_range = questions_.equal_range(answer.name());
|
|
const auto it = std::find_if(
|
|
questions_range.first, questions_range.second,
|
|
[&answer](const auto& pair) {
|
|
return (pair.second->question().dns_type() == DnsType::kANY ||
|
|
IsNegativeResponseFor(answer,
|
|
pair.second->question().dns_type()) ||
|
|
pair.second->question().dns_type() == answer.dns_type()) &&
|
|
(pair.second->question().dns_class() == DnsClass::kANY ||
|
|
pair.second->question().dns_class() == answer.dns_class());
|
|
});
|
|
if (it != questions_range.second) {
|
|
return true;
|
|
}
|
|
|
|
// If not, check if it corresponds to an already existing record. This is
|
|
// required because records which are already stored may either have been
|
|
// received in an additional records section, or are associated with a query
|
|
// which is no longer active.
|
|
std::vector<DnsType> types{answer.dns_type()};
|
|
if (answer.dns_type() == DnsType::kNSEC) {
|
|
const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
|
|
types = nsec_rdata.types();
|
|
}
|
|
|
|
for (DnsType type : types) {
|
|
std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
|
|
records_.Find(answer.name(), type, answer.dns_class());
|
|
if (!trackers.empty()) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// In all other cases, the record isn't relevant. Drop it.
|
|
return false;
|
|
}
|
|
|
|
void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
|
|
const MdnsRecord& record) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
if (!tracker->is_negative_response()) {
|
|
ProcessCallbacks(record, RecordChangedEvent::kExpired);
|
|
}
|
|
|
|
records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
|
|
return tracker == &it_tracker;
|
|
});
|
|
}
|
|
|
|
void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
// Skip all records that can't be processed.
|
|
if (!CanBeProcessed(record.dns_type())) {
|
|
return;
|
|
}
|
|
|
|
// Ignore NSEC records if the embedder has configured us to do so.
|
|
if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
|
|
return;
|
|
}
|
|
|
|
// Get the types which the received record is associated with. In most cases
|
|
// this will only be the type of the provided record, but in the case of
|
|
// NSEC records this will be all records which the record dictates the
|
|
// nonexistence of.
|
|
std::vector<DnsType> types;
|
|
int types_count = 0;
|
|
const DnsType* types_ptr = nullptr;
|
|
if (record.dns_type() == DnsType::kNSEC) {
|
|
const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
|
|
if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
|
|
DnsType::kANY) != nsec_rdata.types().end()) {
|
|
types_ptr = kTranslatedNsecAnyQueryTypes.data();
|
|
types_count = kTranslatedNsecAnyQueryTypes.size();
|
|
} else {
|
|
types_ptr = nsec_rdata.types().data();
|
|
types_count = nsec_rdata.types().size();
|
|
}
|
|
} else {
|
|
types.push_back(record.dns_type());
|
|
types_ptr = types.data();
|
|
types_count = types.size();
|
|
}
|
|
|
|
// Apply the update for each type that the record is associated with.
|
|
for (int i = 0; i < types_count; ++i) {
|
|
DnsType dns_type = types_ptr[i];
|
|
switch (record.record_type()) {
|
|
case RecordType::kShared: {
|
|
ProcessSharedRecord(record, dns_type);
|
|
break;
|
|
}
|
|
case RecordType::kUnique: {
|
|
ProcessUniqueRecord(record, dns_type);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
|
|
DnsType dns_type) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
OSP_DCHECK(record.record_type() == RecordType::kShared);
|
|
|
|
// By design, NSEC records are never shared records.
|
|
if (record.dns_type() == DnsType::kNSEC) {
|
|
return;
|
|
}
|
|
|
|
// For any records updated, this host already has this shared record. Since
|
|
// the RDATA matches, this is only a TTL update.
|
|
auto check = [&record](const MdnsRecordTracker& tracker) {
|
|
return record.dns_type() == tracker.dns_type() &&
|
|
record.dns_class() == tracker.dns_class() &&
|
|
record.rdata() == tracker.rdata();
|
|
};
|
|
auto updated_count = records_.Update(record, std::move(check));
|
|
|
|
if (!updated_count) {
|
|
// Have never before seen this shared record, insert a new one.
|
|
AddRecord(record, dns_type);
|
|
ProcessCallbacks(record, RecordChangedEvent::kCreated);
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
|
|
DnsType dns_type) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
OSP_DCHECK(record.record_type() == RecordType::kUnique);
|
|
|
|
std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
|
|
records_.Find(record.name(), dns_type, record.dns_class());
|
|
size_t num_records_for_key = trackers.size();
|
|
|
|
// Have not seen any records with this key before. This case is expected the
|
|
// first time a record is received.
|
|
if (num_records_for_key == size_t{0}) {
|
|
const bool will_exist = record.dns_type() != DnsType::kNSEC;
|
|
AddRecord(record, dns_type);
|
|
if (will_exist) {
|
|
ProcessCallbacks(record, RecordChangedEvent::kCreated);
|
|
}
|
|
} else if (num_records_for_key == size_t{1}) {
|
|
// There is exactly one tracker associated with this key. This is the
|
|
// expected case when a record matching this one has already been seen.
|
|
ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
|
|
} else {
|
|
// Multiple records with the same key.
|
|
ProcessMultiTrackedUniqueRecord(record, dns_type);
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
|
|
const MdnsRecord& record,
|
|
const MdnsRecordTracker& tracker) {
|
|
const bool existed_previously = !tracker.is_negative_response();
|
|
const bool will_exist = record.dns_type() != DnsType::kNSEC;
|
|
|
|
// Calculate the callback to call on record update success while the old
|
|
// record still exists.
|
|
MdnsRecord record_for_callback = record;
|
|
if (existed_previously && !will_exist) {
|
|
record_for_callback =
|
|
MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
|
|
tracker.record_type(), tracker.ttl(), tracker.rdata());
|
|
}
|
|
|
|
auto on_rdata_change = [this, r = std::move(record_for_callback),
|
|
existed_previously,
|
|
will_exist](const MdnsRecordTracker& tracker) {
|
|
// If RDATA on the record is different, notify that the record has
|
|
// been updated.
|
|
if (existed_previously && will_exist) {
|
|
ProcessCallbacks(r, RecordChangedEvent::kUpdated);
|
|
} else if (existed_previously) {
|
|
// Do not expire the tracker, because it still holds an NSEC record.
|
|
ProcessCallbacks(r, RecordChangedEvent::kExpired);
|
|
} else if (will_exist) {
|
|
ProcessCallbacks(r, RecordChangedEvent::kCreated);
|
|
}
|
|
};
|
|
|
|
int updated_count = records_.Update(
|
|
record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
|
|
std::move(on_rdata_change));
|
|
OSP_DCHECK_EQ(updated_count, 1);
|
|
}
|
|
|
|
void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
|
|
DnsType dns_type) {
|
|
auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
|
|
return tracker.dns_type() == dns_type &&
|
|
tracker.dns_class() == record.dns_class() &&
|
|
tracker.rdata() == record.rdata();
|
|
};
|
|
int update_count = records_.Update(
|
|
record, std::move(update_check),
|
|
[](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
|
|
OSP_DCHECK_LE(update_count, 1);
|
|
|
|
auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
|
|
return tracker.dns_type() == dns_type &&
|
|
tracker.dns_class() == record.dns_class() &&
|
|
tracker.rdata() != record.rdata();
|
|
};
|
|
int expire_count =
|
|
records_.ExpireSoon(record.name(), std::move(expire_check));
|
|
OSP_DCHECK_GE(expire_count, 1);
|
|
|
|
// Did not find an existing record to update.
|
|
if (!update_count && !expire_count) {
|
|
AddRecord(record, dns_type);
|
|
if (record.dns_type() != DnsType::kNSEC) {
|
|
ProcessCallbacks(record, RecordChangedEvent::kCreated);
|
|
}
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
|
|
RecordChangedEvent event) {
|
|
OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
|
|
|
|
std::vector<PendingQueryChange> pending_changes;
|
|
auto callbacks_it = callbacks_.equal_range(record.name());
|
|
for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
|
|
const CallbackInfo& callback_info = entry->second;
|
|
if ((callback_info.dns_type == DnsType::kANY ||
|
|
record.dns_type() == callback_info.dns_type) &&
|
|
(callback_info.dns_class == DnsClass::kANY ||
|
|
record.dns_class() == callback_info.dns_class)) {
|
|
std::vector<PendingQueryChange> new_changes =
|
|
callback_info.callback->OnRecordChanged(record, event);
|
|
pending_changes.insert(pending_changes.end(), new_changes.begin(),
|
|
new_changes.end());
|
|
}
|
|
}
|
|
|
|
ApplyPendingChanges(std::move(pending_changes));
|
|
}
|
|
|
|
void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
|
|
auto tracker = std::make_unique<MdnsQuestionTracker>(
|
|
question, sender_, task_runner_, now_function_, random_delay_, config_);
|
|
MdnsQuestionTracker* ptr = tracker.get();
|
|
questions_.emplace(question.name(), std::move(tracker));
|
|
|
|
// Let all records associated with this question know that there is a new
|
|
// query that can be used for their refresh.
|
|
std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
|
|
records_.Find(question.name(), question.dns_type(), question.dns_class());
|
|
for (const MdnsRecordTracker& tracker : trackers) {
|
|
// NOTE: When the pointed to object is deleted, its dtor removes itself
|
|
// from all associated records.
|
|
ptr->AddAssociatedRecord(&tracker);
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
|
|
// Add the new record.
|
|
const auto& tracker = records_.StartTracking(record, type);
|
|
|
|
// Let all questions associated with this record know that there is a new
|
|
// record that answers them (for known answer suppression).
|
|
auto query_it = questions_.equal_range(record.name());
|
|
for (auto entry = query_it.first; entry != query_it.second; ++entry) {
|
|
const MdnsQuestion& query = entry->second->question();
|
|
const bool is_relevant_type =
|
|
type == DnsType::kANY || type == query.dns_type();
|
|
const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
|
|
record.dns_class() == query.dns_class();
|
|
if (is_relevant_type && is_relevant_class) {
|
|
// NOTE: When the pointed to object is deleted, its dtor removes itself
|
|
// from all associated queries.
|
|
entry->second->AddAssociatedRecord(&tracker);
|
|
}
|
|
}
|
|
}
|
|
|
|
void MdnsQuerier::ApplyPendingChanges(
|
|
std::vector<PendingQueryChange> pending_changes) {
|
|
for (auto& pending_change : pending_changes) {
|
|
switch (pending_change.change_type) {
|
|
case PendingQueryChange::kStartQuery:
|
|
StartQuery(std::move(pending_change.name), pending_change.dns_type,
|
|
pending_change.dns_class, pending_change.callback);
|
|
break;
|
|
case PendingQueryChange::kStopQuery:
|
|
StopQuery(std::move(pending_change.name), pending_change.dns_type,
|
|
pending_change.dns_class, pending_change.callback);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace discovery
|
|
} // namespace openscreen
|