// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ #include #include #include #include #include #include "absl/types/optional.h" #include "osp/public/message_demuxer.h" #include "osp/public/network_service_manager.h" #include "osp/public/protocol_connection.h" #include "platform/base/error.h" #include "platform/base/macros.h" #include "util/osp_logging.h" namespace openscreen { namespace osp { template using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*); // Provides a uniform way of accessing import properties of a request/response // message pair from a template: request encode function, response decode // function, request serializable data member. template struct DefaultRequestCoderTraits { public: using RequestMsgType = typename T::RequestMsgType; static constexpr MessageEncodingFunction kEncoder = T::kEncoder; static constexpr MessageDecodingFunction kDecoder = T::kDecoder; static const RequestMsgType* serial_request(const T& data) { return &data.request; } static RequestMsgType* serial_request(T& data) { return &data.request; } }; // Provides a wrapper for the common pattern of sending a request message and // waiting for a response message with a matching |request_id| field. It also // handles the business of queueing messages to be sent until a protocol // connection is available. // // Messages are written using WriteMessage. This will queue messages if there // is no protocol connection or write them immediately if there is. When a // matching response is received via the MessageDemuxer (taken from the global // ProtocolConnectionClient), OnMatchedResponse is called on the provided // Delegate object along with the original request that it matches. template > class RequestResponseHandler : public MessageDemuxer::MessageCallback { public: class Delegate { public: virtual ~Delegate() = default; virtual void OnMatchedResponse(RequestT* request, typename RequestT::ResponseMsgType* response, uint64_t endpoint_id) = 0; virtual void OnError(RequestT* request, Error error) = 0; }; explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} ~RequestResponseHandler() { Reset(); } void Reset() { connection_ = nullptr; for (auto& message : to_send_) { delegate_->OnError(&message.request, Error::Code::kRequestCancelled); } to_send_.clear(); for (auto& message : sent_) { delegate_->OnError(&message.request, Error::Code::kRequestCancelled); } sent_.clear(); response_watch_ = MessageDemuxer::MessageWatch(); } // Write a message to the underlying protocol connection, or queue it until // one is provided via SetConnection. If |id| is provided, it can be used to // cancel the message via CancelMessage. template typename std::enable_if< !std::is_lvalue_reference::value && std::is_same::type, RequestT>::value, Error>::type WriteMessage(absl::optional id, RequestTRval&& message) { auto* request_msg = RequestCoderTraits::serial_request(message); if (connection_) { request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); Error result = connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); if (!result.ok()) { return result; } sent_.emplace_back(RequestWithId{id, std::move(message)}); EnsureResponseWatch(); } else { to_send_.emplace_back(RequestWithId{id, std::move(message)}); } return Error::None(); } template typename std::enable_if< !std::is_lvalue_reference::value && std::is_same::type, RequestT>::value, Error>::type WriteMessage(RequestTRval&& message) { return WriteMessage(absl::nullopt, std::move(message)); } // Remove the message that was originally written with |id| from the send and // sent queues so that we are no longer looking for a response. void CancelMessage(uint64_t id) { to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), [&id](const RequestWithId& msg) { return id == msg.id; }), to_send_.end()); sent_.erase(std::remove_if( sent_.begin(), sent_.end(), [&id](const RequestWithId& msg) { return id == msg.id; }), sent_.end()); if (sent_.empty()) { response_watch_ = MessageDemuxer::MessageWatch(); } } // Assign a ProtocolConnection to this handler for writing messages. void SetConnection(ProtocolConnection* connection) { connection_ = connection; for (auto& message : to_send_) { auto* request_msg = RequestCoderTraits::serial_request(message.request); request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); Error result = connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); if (result.ok()) { sent_.emplace_back(std::move(message)); } else { delegate_->OnError(&message.request, result); } } if (!to_send_.empty()) { EnsureResponseWatch(); } to_send_.clear(); } // MessageDemuxer::MessageCallback overrides. ErrorOr OnStreamMessage(uint64_t endpoint_id, uint64_t connection_id, msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, Clock::time_point now) override { if (message_type != RequestT::kResponseType) { return 0; } typename RequestT::ResponseMsgType response; ssize_t result = RequestCoderTraits::kDecoder(buffer, buffer_size, &response); if (result < 0) { return 0; } auto it = std::find_if( sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { return RequestCoderTraits::serial_request(msg.request)->request_id == response.request_id; }); if (it != sent_.end()) { delegate_->OnMatchedResponse(&it->request, &response, connection_->endpoint_id()); sent_.erase(it); if (sent_.empty()) { response_watch_ = MessageDemuxer::MessageWatch(); } } else { OSP_LOG_WARN << "got response for unknown request id: " << response.request_id; } return result; } private: struct RequestWithId { absl::optional id; RequestT request; }; void EnsureResponseWatch() { if (!response_watch_) { response_watch_ = NetworkServiceManager::Get() ->GetProtocolConnectionClient() ->message_demuxer() ->WatchMessageType(connection_->endpoint_id(), RequestT::kResponseType, this); } } uint64_t GetNextRequestId(uint64_t endpoint_id) { return NetworkServiceManager::Get() ->GetProtocolConnectionClient() ->endpoint_request_ids() ->GetNextRequestId(endpoint_id); } ProtocolConnection* connection_ = nullptr; Delegate* const delegate_; std::vector to_send_; std::vector sent_; MessageDemuxer::MessageWatch response_watch_; OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); }; } // namespace osp } // namespace openscreen #endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_