You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
230 lines
8.1 KiB
230 lines
8.1 KiB
// Copyright 2019 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
|
|
#define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <type_traits>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/types/optional.h"
|
|
#include "osp/public/message_demuxer.h"
|
|
#include "osp/public/network_service_manager.h"
|
|
#include "osp/public/protocol_connection.h"
|
|
#include "platform/base/error.h"
|
|
#include "platform/base/macros.h"
|
|
#include "util/osp_logging.h"
|
|
|
|
namespace openscreen {
|
|
namespace osp {
|
|
|
|
// Pointer to a function that decodes a message of type |T| from the byte
// range described by the first two arguments and stores it through the third.
// RequestResponseHandler::OnStreamMessage invokes it as
// kDecoder(buffer, buffer_size, &response); a negative return value is treated
// as a decoding failure, and a non-negative value is handed back to the caller
// of OnStreamMessage (presumably the number of bytes consumed -- confirm
// against the generated decoders).
template <typename T>
|
|
using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*);
|
|
|
|
// Provides a uniform way of accessing important properties of a request/response
|
|
// message pair from a template: request encode function, response decode
|
|
// function, request serializable data member.
|
|
template <typename T>
|
|
struct DefaultRequestCoderTraits {
|
|
public:
|
|
using RequestMsgType = typename T::RequestMsgType;
|
|
static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =
|
|
T::kEncoder;
|
|
static constexpr MessageDecodingFunction<typename T::ResponseMsgType>
|
|
kDecoder = T::kDecoder;
|
|
|
|
static const RequestMsgType* serial_request(const T& data) {
|
|
return &data.request;
|
|
}
|
|
static RequestMsgType* serial_request(T& data) { return &data.request; }
|
|
};
|
|
|
|
// Provides a wrapper for the common pattern of sending a request message and
|
|
// waiting for a response message with a matching |request_id| field. It also
|
|
// handles the business of queueing messages to be sent until a protocol
|
|
// connection is available.
|
|
//
|
|
// Messages are written using WriteMessage. This will queue messages if there
|
|
// is no protocol connection or write them immediately if there is. When a
|
|
// matching response is received via the MessageDemuxer (taken from the global
|
|
// ProtocolConnectionClient), OnMatchedResponse is called on the provided
|
|
// Delegate object along with the original request that it matches.
|
|
template <typename RequestT,
|
|
typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>>
|
|
class RequestResponseHandler : public MessageDemuxer::MessageCallback {
|
|
public:
|
|
class Delegate {
|
|
public:
|
|
virtual ~Delegate() = default;
|
|
|
|
virtual void OnMatchedResponse(RequestT* request,
|
|
typename RequestT::ResponseMsgType* response,
|
|
uint64_t endpoint_id) = 0;
|
|
virtual void OnError(RequestT* request, Error error) = 0;
|
|
};
|
|
|
|
explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {}
|
|
~RequestResponseHandler() { Reset(); }
|
|
|
|
void Reset() {
|
|
connection_ = nullptr;
|
|
for (auto& message : to_send_) {
|
|
delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
|
|
}
|
|
to_send_.clear();
|
|
for (auto& message : sent_) {
|
|
delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
|
|
}
|
|
sent_.clear();
|
|
response_watch_ = MessageDemuxer::MessageWatch();
|
|
}
|
|
|
|
// Write a message to the underlying protocol connection, or queue it until
|
|
// one is provided via SetConnection. If |id| is provided, it can be used to
|
|
// cancel the message via CancelMessage.
|
|
template <typename RequestTRval>
|
|
typename std::enable_if<
|
|
!std::is_lvalue_reference<RequestTRval>::value &&
|
|
std::is_same<typename std::decay<RequestTRval>::type,
|
|
RequestT>::value,
|
|
Error>::type
|
|
WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) {
|
|
auto* request_msg = RequestCoderTraits::serial_request(message);
|
|
if (connection_) {
|
|
request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
|
|
Error result =
|
|
connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
|
|
if (!result.ok()) {
|
|
return result;
|
|
}
|
|
sent_.emplace_back(RequestWithId{id, std::move(message)});
|
|
EnsureResponseWatch();
|
|
} else {
|
|
to_send_.emplace_back(RequestWithId{id, std::move(message)});
|
|
}
|
|
return Error::None();
|
|
}
|
|
|
|
template <typename RequestTRval>
|
|
typename std::enable_if<
|
|
!std::is_lvalue_reference<RequestTRval>::value &&
|
|
std::is_same<typename std::decay<RequestTRval>::type,
|
|
RequestT>::value,
|
|
Error>::type
|
|
WriteMessage(RequestTRval&& message) {
|
|
return WriteMessage(absl::nullopt, std::move(message));
|
|
}
|
|
|
|
// Remove the message that was originally written with |id| from the send and
|
|
// sent queues so that we are no longer looking for a response.
|
|
void CancelMessage(uint64_t id) {
|
|
to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(),
|
|
[&id](const RequestWithId& msg) {
|
|
return id == msg.id;
|
|
}),
|
|
to_send_.end());
|
|
sent_.erase(std::remove_if(
|
|
sent_.begin(), sent_.end(),
|
|
[&id](const RequestWithId& msg) { return id == msg.id; }),
|
|
sent_.end());
|
|
if (sent_.empty()) {
|
|
response_watch_ = MessageDemuxer::MessageWatch();
|
|
}
|
|
}
|
|
|
|
// Assign a ProtocolConnection to this handler for writing messages.
|
|
void SetConnection(ProtocolConnection* connection) {
|
|
connection_ = connection;
|
|
for (auto& message : to_send_) {
|
|
auto* request_msg = RequestCoderTraits::serial_request(message.request);
|
|
request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
|
|
Error result =
|
|
connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
|
|
if (result.ok()) {
|
|
sent_.emplace_back(std::move(message));
|
|
} else {
|
|
delegate_->OnError(&message.request, result);
|
|
}
|
|
}
|
|
if (!to_send_.empty()) {
|
|
EnsureResponseWatch();
|
|
}
|
|
to_send_.clear();
|
|
}
|
|
|
|
// MessageDemuxer::MessageCallback overrides.
|
|
ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
|
|
uint64_t connection_id,
|
|
msgs::Type message_type,
|
|
const uint8_t* buffer,
|
|
size_t buffer_size,
|
|
Clock::time_point now) override {
|
|
if (message_type != RequestT::kResponseType) {
|
|
return 0;
|
|
}
|
|
typename RequestT::ResponseMsgType response;
|
|
ssize_t result =
|
|
RequestCoderTraits::kDecoder(buffer, buffer_size, &response);
|
|
if (result < 0) {
|
|
return 0;
|
|
}
|
|
auto it = std::find_if(
|
|
sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) {
|
|
return RequestCoderTraits::serial_request(msg.request)->request_id ==
|
|
response.request_id;
|
|
});
|
|
if (it != sent_.end()) {
|
|
delegate_->OnMatchedResponse(&it->request, &response,
|
|
connection_->endpoint_id());
|
|
sent_.erase(it);
|
|
if (sent_.empty()) {
|
|
response_watch_ = MessageDemuxer::MessageWatch();
|
|
}
|
|
} else {
|
|
OSP_LOG_WARN << "got response for unknown request id: "
|
|
<< response.request_id;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
private:
|
|
struct RequestWithId {
|
|
absl::optional<uint64_t> id;
|
|
RequestT request;
|
|
};
|
|
|
|
void EnsureResponseWatch() {
|
|
if (!response_watch_) {
|
|
response_watch_ = NetworkServiceManager::Get()
|
|
->GetProtocolConnectionClient()
|
|
->message_demuxer()
|
|
->WatchMessageType(connection_->endpoint_id(),
|
|
RequestT::kResponseType, this);
|
|
}
|
|
}
|
|
|
|
uint64_t GetNextRequestId(uint64_t endpoint_id) {
|
|
return NetworkServiceManager::Get()
|
|
->GetProtocolConnectionClient()
|
|
->endpoint_request_ids()
|
|
->GetNextRequestId(endpoint_id);
|
|
}
|
|
|
|
ProtocolConnection* connection_ = nullptr;
|
|
Delegate* const delegate_;
|
|
std::vector<RequestWithId> to_send_;
|
|
std::vector<RequestWithId> sent_;
|
|
MessageDemuxer::MessageWatch response_watch_;
|
|
|
|
OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler);
|
|
};
|
|
|
|
} // namespace osp
|
|
} // namespace openscreen
|
|
|
|
#endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
|