#!/usr/bin/env python3 # # Copyright 2020 - The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod from datetime import datetime, timedelta from mobly import signals from threading import Condition from cert.event_stream import static_remaining_time_delta from cert.truth import assertThat class IHasBehaviors(ABC): @abstractmethod def get_behaviors(self): pass def anything(): return lambda obj: True def when(has_behaviors): assertThat(isinstance(has_behaviors, IHasBehaviors)).isTrue() return has_behaviors.get_behaviors() def IGNORE_UNHANDLED(obj): pass class SingleArgumentBehavior(object): def __init__(self, reply_stage_factory): self._reply_stage_factory = reply_stage_factory self._instances = [] self._invoked_obj = [] self._invoked_condition = Condition() self.set_default_to_crash() def begin(self, matcher): return PersistenceStage(self, matcher, self._reply_stage_factory) def append(self, behavior_instance): self._instances.append(behavior_instance) def set_default(self, fn): assertThat(fn).isNotNone() self._default_fn = fn def set_default_to_crash(self): self._default_fn = None def set_default_to_ignore(self): self._default_fn = IGNORE_UNHANDLED def run(self, obj): for instance in self._instances: if instance.try_run(obj): self.__obj_invoked(obj) return if self._default_fn is not None: # IGNORE_UNHANDLED is also a default fn self._default_fn(obj) self.__obj_invoked(obj) else: raise signals.TestFailure( "%s: behavior for %s went unhandled" % (self._reply_stage_factory().__class__.__name__, obj), extras=None) def __obj_invoked(self, obj): self._invoked_condition.acquire() self._invoked_obj.append(obj) self._invoked_condition.notify() self._invoked_condition.release() def wait_until_invoked(self, matcher, times, timeout): end_time = datetime.now() + timeout invoked_times = 0 while datetime.now() < end_time and invoked_times < times: remaining = static_remaining_time_delta(end_time) invoked_times = sum((matcher(i) for i in self._invoked_obj)) self._invoked_condition.acquire() self._invoked_condition.wait(remaining.total_seconds()) self._invoked_condition.release() return invoked_times == times class PersistenceStage(object): def __init__(self, behavior, matcher, reply_stage_factory): self._behavior = behavior self._matcher = matcher self._reply_stage_factory = reply_stage_factory def then(self, times=1): reply_stage = self._reply_stage_factory() reply_stage.init(self._behavior, self._matcher, times) return reply_stage def always(self): return self.then(times=-1) class ReplyStage(object): def init(self, behavior, matcher, persistence): self._behavior = behavior self._matcher = matcher self._persistence = persistence def _commit(self, fn): self._behavior.append(BehaviorInstance(self._matcher, self._persistence, fn)) class BehaviorInstance(object): def __init__(self, matcher, persistence, fn): self._matcher = matcher self._persistence = persistence self._fn = fn self._called_count = 0 def try_run(self, obj): if not self._matcher(obj): return False if self._persistence >= 0: if self._called_count >= self._persistence: return False self._called_count += 1 self._fn(obj) return True class BoundVerificationStage(object): def __init__(self, behavior, matcher, timeout): self._behavior = behavior self._matcher = matcher self._timeout = timeout def times(self, times=1): return self._behavior.wait_until_invoked(self._matcher, times, self._timeout) class WaitForBehaviorSubject(object): def __init__(self, behaviors, timeout): self._behaviors = behaviors self._timeout = timeout def __getattr__(self, item): behavior = getattr(self._behaviors, item + "_behavior") t = self._timeout return lambda matcher: BoundVerificationStage(behavior, matcher, t) def wait_until(i_has_behaviors, timeout=timedelta(seconds=3)): return WaitForBehaviorSubject(i_has_behaviors.get_behaviors(), timeout)