// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #ifndef MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_ #define MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_ #include #include #include "base/bind.h" #include "base/callback.h" #include "base/macros.h" #include "base/memory/ptr_util.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/interface_ptr.h" #include "mojo/public/cpp/bindings/interface_request.h" #include "mojo/public/cpp/bindings/message.h" namespace mojo { template struct BindingSetTraits; template struct BindingSetTraits> { using ProxyType = InterfacePtr; using RequestType = InterfaceRequest; using BindingType = Binding; using ImplPointerType = typename BindingType::ImplPointerType; static RequestType MakeRequest(ProxyType* proxy) { return mojo::MakeRequest(proxy); } }; using BindingId = size_t; template struct BindingSetContextTraits { using Type = ContextType; static constexpr bool SupportsContext() { return true; } }; template <> struct BindingSetContextTraits { // NOTE: This choice of Type only matters insofar as it affects the size of // the |context_| field of a BindingSetBase::Entry with void context. The // context value is never used in this case. using Type = bool; static constexpr bool SupportsContext() { return false; } }; // Generic definition used for BindingSet and AssociatedBindingSet to own a // collection of bindings which point to the same implementation. // // If |ContextType| is non-void, then every added binding must include a context // value of that type, and |dispatch_context()| will return that value during // the extent of any message dispatch targeting that specific binding. template class BindingSetBase { public: using ContextTraits = BindingSetContextTraits; using Context = typename ContextTraits::Type; using PreDispatchCallback = base::Callback; using Traits = BindingSetTraits; using ProxyType = typename Traits::ProxyType; using RequestType = typename Traits::RequestType; using ImplPointerType = typename Traits::ImplPointerType; BindingSetBase() : weak_ptr_factory_(this) {} void set_connection_error_handler(base::RepeatingClosure error_handler) { error_handler_ = std::move(error_handler); error_with_reason_handler_.Reset(); } void set_connection_error_with_reason_handler( RepeatingConnectionErrorWithReasonCallback error_handler) { error_with_reason_handler_ = std::move(error_handler); error_handler_.Reset(); } // Sets a callback to be invoked immediately before dispatching any message or // error received by any of the bindings in the set. This may only be used // with a non-void |ContextType|. void set_pre_dispatch_handler(const PreDispatchCallback& handler) { static_assert(ContextTraits::SupportsContext(), "Pre-dispatch handler usage requires non-void context type."); pre_dispatch_handler_ = handler; } // Adds a new binding to the set which binds |request| to |impl| with no // additional context. BindingId AddBinding(ImplPointerType impl, RequestType request) { static_assert(!ContextTraits::SupportsContext(), "Context value required for non-void context type."); return AddBindingImpl(std::move(impl), std::move(request), false); } // Adds a new binding associated with |context|. BindingId AddBinding(ImplPointerType impl, RequestType request, Context context) { static_assert(ContextTraits::SupportsContext(), "Context value unsupported for void context type."); return AddBindingImpl(std::move(impl), std::move(request), std::move(context)); } // Removes a binding from the set. Note that this is safe to call even if the // binding corresponding to |id| has already been removed. // // Returns |true| if the binding was removed and |false| if it didn't exist. bool RemoveBinding(BindingId id) { auto it = bindings_.find(id); if (it == bindings_.end()) return false; bindings_.erase(it); return true; } // Swaps the interface implementation with a different one, to allow tests // to modify behavior. // // Returns the existing interface implementation to the caller. ImplPointerType SwapImplForTesting(BindingId id, ImplPointerType new_impl) { auto it = bindings_.find(id); if (it == bindings_.end()) return nullptr; return it->second->SwapImplForTesting(new_impl); } void CloseAllBindings() { bindings_.clear(); } bool empty() const { return bindings_.empty(); } size_t size() const { return bindings_.size(); } // Implementations may call this when processing a dispatched message or // error. During the extent of message or error dispatch, this will return the // context associated with the specific binding which received the message or // error. Use AddBinding() to associated a context with a specific binding. const Context& dispatch_context() const { static_assert(ContextTraits::SupportsContext(), "dispatch_context() requires non-void context type."); DCHECK(dispatch_context_); return *dispatch_context_; } // Implementations may call this when processing a dispatched message or // error. During the extent of message or error dispatch, this will return the // BindingId of the specific binding which received the message or error. BindingId dispatch_binding() const { DCHECK(dispatch_context_); return dispatch_binding_; } // Reports the currently dispatching Message as bad and closes the binding the // message was received from. Note that this is only legal to call from // directly within the stack frame of a message dispatch. If you need to do // asynchronous work before you can determine the legitimacy of a message, use // GetBadMessageCallback() and retain its result until you're ready to invoke // or discard it. void ReportBadMessage(const std::string& error) { GetBadMessageCallback().Run(error); } // Acquires a callback which may be run to report the currently dispatching // Message as bad and close the binding the message was received from. Note // that this is only legal to call from directly within the stack frame of a // message dispatch, but the returned callback may be called exactly once any // time thereafter as long as the binding set itself hasn't been destroyed yet // to report the message as bad. This may only be called once per message. // The returned callback must be called on the BindingSet's own sequence. ReportBadMessageCallback GetBadMessageCallback() { DCHECK(dispatch_context_); return base::BindOnce( [](ReportBadMessageCallback error_callback, base::WeakPtr binding_set, BindingId binding_id, const std::string& error) { std::move(error_callback).Run(error); if (binding_set) binding_set->RemoveBinding(binding_id); }, mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr(), dispatch_binding()); } void FlushForTesting() { DCHECK(!is_flushing_); is_flushing_ = true; for (auto& binding : bindings_) if (binding.second) binding.second->FlushForTesting(); is_flushing_ = false; // Clean up any bindings that were destroyed. for (auto it = bindings_.begin(); it != bindings_.end();) { if (!it->second) it = bindings_.erase(it); else ++it; } } private: friend class Entry; class Entry { public: Entry(ImplPointerType impl, RequestType request, BindingSetBase* binding_set, BindingId binding_id, Context context) : binding_(std::move(impl), std::move(request)), binding_set_(binding_set), binding_id_(binding_id), context_(std::move(context)) { binding_.AddFilter(std::make_unique(this)); binding_.set_connection_error_with_reason_handler( base::BindOnce(&Entry::OnConnectionError, base::Unretained(this))); } void FlushForTesting() { binding_.FlushForTesting(); } ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { return binding_.SwapImplForTesting(new_impl); } private: class DispatchFilter : public MessageReceiver { public: explicit DispatchFilter(Entry* entry) : entry_(entry) {} ~DispatchFilter() override {} private: // MessageReceiver: bool Accept(Message* message) override { entry_->WillDispatch(); return true; } Entry* entry_; DISALLOW_COPY_AND_ASSIGN(DispatchFilter); }; void WillDispatch() { binding_set_->SetDispatchContext(&context_, binding_id_); } void OnConnectionError(uint32_t custom_reason, const std::string& description) { WillDispatch(); binding_set_->OnConnectionError(binding_id_, custom_reason, description); } BindingType binding_; BindingSetBase* const binding_set_; const BindingId binding_id_; Context const context_; DISALLOW_COPY_AND_ASSIGN(Entry); }; void SetDispatchContext(const Context* context, BindingId binding_id) { dispatch_context_ = context; dispatch_binding_ = binding_id; if (!pre_dispatch_handler_.is_null()) pre_dispatch_handler_.Run(*context); } BindingId AddBindingImpl(ImplPointerType impl, RequestType request, Context context) { BindingId id = next_binding_id_++; DCHECK_GE(next_binding_id_, 0u); auto entry = std::make_unique(std::move(impl), std::move(request), this, id, std::move(context)); bindings_.insert(std::make_pair(id, std::move(entry))); return id; } void OnConnectionError(BindingId id, uint32_t custom_reason, const std::string& description) { auto it = bindings_.find(id); DCHECK(it != bindings_.end()); // We keep the Entry alive throughout error dispatch. std::unique_ptr entry = std::move(it->second); if (!is_flushing_) bindings_.erase(it); if (error_handler_) { error_handler_.Run(); } else if (error_with_reason_handler_) { error_with_reason_handler_.Run(custom_reason, description); } } base::RepeatingClosure error_handler_; RepeatingConnectionErrorWithReasonCallback error_with_reason_handler_; PreDispatchCallback pre_dispatch_handler_; BindingId next_binding_id_ = 0; std::map> bindings_; bool is_flushing_ = false; const Context* dispatch_context_ = nullptr; BindingId dispatch_binding_; base::WeakPtrFactory weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(BindingSetBase); }; template using BindingSet = BindingSetBase, ContextType>; } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_