testslide/mock_callable.py (938 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import functools
import inspect
from inspect import Traceback
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from unittest.mock import Mock
import testslide
from testslide.lib import _validate_return_type, _wrap_signature_and_type_validation
from testslide.strict_mock import StrictMock
from .lib import CoroutineValueError, _bail_if_private, _is_a_builtin
from .patch import _is_instance_method, _patch
if TYPE_CHECKING:
from testslide.matchers import RegexMatches # noqa: F401
from testslide.mock_constructor import _MockConstructorDSL # noqa: F401
def mock_callable(
    target: Any,
    method: str,
    allow_private: bool = False,
    # type_validation accepted values:
    # * None: type validation will be enabled except if target is a StrictMock
    #   with disabled type validation
    # * True: type validation will be enabled (regardless of target type)
    # * False: type validation will be disabled
    type_validation: Optional[bool] = None,
) -> "_MockCallableDSL":
    """
    Start a mock_callable() chain for the callable attribute ``method`` of
    ``target``.

    :param target: the object/class owning the attribute, or a string that is
        resolved with ``testslide._importer()``.
    :param method: name of the callable attribute to mock.
    :param allow_private: forwarded to the private-name check done when the
        patch is validated.
    :param type_validation: see the comment on the parameter above.
    :return: a _MockCallableDSL used to define behaviors and call assertions.
    """
    # Record the frame of our direct caller, exactly like
    # mock_async_callable() does. Going two frames up pointed one level too
    # far and is None when called from module top level, which makes
    # inspect.getframeinfo() raise TypeError.
    caller_frame = inspect.currentframe().f_back  # type: ignore
    # loading the context ends up reading files from disk and that might block
    # the event loop, so we don't do it.
    caller_frame_info = inspect.getframeinfo(caller_frame, context=0)  # type: ignore
    return _MockCallableDSL(
        target,
        method,
        caller_frame_info,
        allow_private=allow_private,
        type_validation=type_validation,
    )
def mock_async_callable(
    target: Union[type, str],
    method: str,
    callable_returns_coroutine: bool = False,
    allow_private: bool = False,
    type_validation: bool = True,
) -> "_MockAsyncCallableDSL":
    """
    Start a mock_async_callable() chain for the coroutine function attribute
    ``method`` of ``target``.

    :param target: the class/object owning the attribute, or a string that is
        resolved with ``testslide._importer()``.
    :param method: name of the attribute to mock.
    :param callable_returns_coroutine: when True, the patch validation also
        accepts an attribute that is not a coroutine function.
    :param allow_private: forwarded to the private-name check done when the
        patch is validated.
    :param type_validation: whether type validation is enabled.
    :return: a _MockAsyncCallableDSL used to define behaviors and assertions.
    """
    # Metadata of the frame that called us. context=0 avoids reading source
    # lines from disk, which could block a running event loop.
    frame_info = inspect.getframeinfo(
        inspect.currentframe().f_back, context=0  # type: ignore
    )
    return _MockAsyncCallableDSL(
        target,
        method,
        frame_info,
        callable_returns_coroutine,
        allow_private,
        type_validation,
    )
# Zero-argument callables that undo patches/registrations made during the
# current test; all of them are run, then the list is cleared, by
# unpatch_all_callable_mocks().
_unpatchers: List[Callable] = []  # noqa T484
def _default_register_assertion(assertion: Callable) -> None:
"""
This method must be redefined by the test framework using mock_callable().
It will be called when a new assertion is defined, passing a callable as an
argument that evaluates that assertion. Every defined assertion during a test
must be called after the test code ends, and before the test finishes.
"""
raise NotImplementedError("This method must be redefined by the test framework")
# Active assertion registry. A test framework replaces it; it is restored to
# the default by unpatch_all_callable_mocks().
register_assertion = _default_register_assertion
# Whether the single shared call-order assertion was already registered for
# the current test.
_call_order_assertion_registered: bool = False
# (target, method, runner) triples, in the order calls were actually received
# by runners that carry a call-order assertion.
_received_ordered_calls: List[Tuple[Any, str, "_BaseRunner"]] = []
# (target, method, runner) triples, in the order call-order assertions were
# declared.
_expected_ordered_calls: List[Tuple[Any, str, "_BaseRunner"]] = []
def unpatch_all_callable_mocks() -> None:
    """
    Remove every active mock_callable() patch and reset the module state.

    Must be called unconditionally after every test. The assertion registry
    and call-order bookkeeping are reset first. Then every registered
    unpatcher is run, even if earlier ones fail. Failures are collected and
    reported together in one RuntimeError after all unpatchers have run.
    """
    global register_assertion, _call_order_assertion_registered

    register_assertion = _default_register_assertion
    _call_order_assertion_registered = False
    _received_ordered_calls.clear()
    _expected_ordered_calls.clear()

    errors = []
    for undo in _unpatchers:
        try:
            undo()
        except Exception as error:
            errors.append(error)
    _unpatchers.clear()

    if errors:
        raise RuntimeError("Exceptions raised when unpatching: {}".format(errors))
def _is_setup() -> bool:
    """Tell whether a test framework has installed its own register_assertion."""
    installed = register_assertion
    return installed is not _default_register_assertion
def _format_target(target: Union[str, type]) -> str:
if hasattr(target, "__repr__"):
return repr(target)
else:
return "{}.{} instance with id {}".format(
target.__module__, type(target).__name__, id(target)
)
def _format_args(indent: int, *args: Any, **kwargs: Any) -> str:
indentation = " " * indent
s = ""
if args:
s += ("{}{}\n").format(indentation, args)
if kwargs:
s += indentation + "{"
if kwargs:
s += "\n"
for k in sorted(kwargs.keys()):
s += "{} {}={},\n".format(indentation, k, repr(kwargs[k]))
s += "{}".format(indentation)
s += "}\n"
return s
def _is_coroutine(obj: Any) -> bool:
return inspect.iscoroutine(obj) or isinstance(obj, asyncio.coroutines.CoroWrapper) # type: ignore
def _is_coroutinefunction(func: Any) -> bool:
    """
    Best-effort check for whether ``func`` is a coroutine function.

    asyncio.iscoroutinefunction() is used instead of the inspect variant,
    because the next Cython version makes the asyncio variant return True
    where the inspect one returns False.

    FIXME: coroutine functions cannot be reliably introspected for builtins
    (https://bugs.python.org/issue38225). Anything _is_a_builtin() reports as
    a builtin is therefore accepted too.
    """
    if asyncio.iscoroutinefunction(func):
        return True
    return _is_a_builtin(func)
##
## Exceptions
##
class UndefinedBehaviorForCall(BaseException):
    """
    A mock was called but has no behavior defined that can serve the call.

    Derives from BaseException, not Exception, so that broad
    ``except Exception`` handlers in the code under test do not swallow it.
    """
class UnexpectedCallReceived(BaseException):
    """
    A mock received a call it was configured to reject.

    Derives from BaseException, not Exception, so that broad
    ``except Exception`` handlers in the code under test do not swallow it.
    """
class UnexpectedCallArguments(BaseException):
    """
    A mock received a call whose arguments match none of its defined calls.

    Derives from BaseException, not Exception, so that broad
    ``except Exception`` handlers in the code under test do not swallow it.
    """
class NotACoroutine(BaseException):
    """
    A behavior that must produce a coroutine produced something else.

    Derives from BaseException, not Exception, so that broad
    ``except Exception`` handlers in the code under test do not swallow it.
    """
##
## Runners
##
class _BaseRunner:
TYPE_VALIDATION = True
def __init__(
self, target: Any, method: str, original_callable: Union[Callable, Mock]
) -> None:
self.target = target
self.method = method
self.original_callable = original_callable
self.accepted_args: Optional[Tuple[Any, Any]] = None
self._call_count: int = 0
self._max_calls: Optional[int] = None
self._has_order_assertion = False
self._accept_partial_call = False
def register_call(self, *args: Any, **kwargs: Any) -> None:
global _received_ordered_calls
if self._has_order_assertion:
_received_ordered_calls.append((self.target, self.method, self))
self.inc_call_count()
@property
def call_count(self) -> int:
return self._call_count
@property
def max_calls(self) -> Optional[int]:
return self._max_calls
def _set_max_calls(self, times: int) -> None:
if not self._max_calls or times < self._max_calls:
self._max_calls = times
def inc_call_count(self) -> None:
self._call_count += 1
if self.max_calls and self._call_count > self.max_calls:
raise UnexpectedCallReceived(
(
"Unexpected call received.\n"
"{}, {}:\n"
" expected to receive at most {} calls with {}"
" but received an extra call."
).format(
_format_target(self.target),
repr(self.method),
self.max_calls,
self._args_message(),
)
)
def add_accepted_args(
self,
_accept_partial_call: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
self.accepted_args = (args, kwargs)
self._accept_partial_call = _accept_partial_call
def can_accept_args(self, *args: Any, **kwargs: Any) -> bool:
if self.accepted_args:
if self._accept_partial_call:
args_match = all(
any(elem == arg for arg in args) for elem in self.accepted_args[0]
)
kwargs_match = all(
elem in kwargs.keys()
and kwargs[elem] == self.accepted_args[1][elem]
for elem in self.accepted_args[1].keys()
)
return args_match and kwargs_match
else:
return self.accepted_args == (args, kwargs)
else:
return True
def _args_message(self) -> str:
if self.accepted_args:
return "arguments:\n{}".format(
_format_args(2, *self.accepted_args[0], **self.accepted_args[1])
)
else:
return "any arguments "
def add_exact_calls_assertion(self, times: int) -> None:
self._set_max_calls(times)
def assertion() -> None:
if times != self.call_count:
raise AssertionError(
(
"calls did not match assertion.\n"
"{}, {}:\n"
" expected: called exactly {} time(s) with {}"
" received: {} call(s)"
).format(
_format_target(self.target),
repr(self.method),
times,
self._args_message(),
self.call_count,
)
)
register_assertion(assertion)
def add_at_least_calls_assertion(self, times: int) -> None:
def assertion() -> None:
if self.call_count < times:
raise AssertionError(
(
"calls did not match assertion.\n"
"{}, {}:\n"
" expected: called at least {} time(s) with {}"
" received: {} call(s)"
).format(
_format_target(self.target),
repr(self.method),
times,
self._args_message(),
self.call_count,
)
)
register_assertion(assertion)
def add_at_most_calls_assertion(self, times: int) -> None:
self._set_max_calls(times)
def assertion() -> None:
if not self.call_count or self.call_count > times:
raise AssertionError(
(
"calls did not match assertion.\n"
"{}, {}:\n"
" expected: called at most {} time(s) with {}"
" received: {} call(s)"
).format(
_format_target(self.target),
repr(self.method),
times,
self._args_message(),
self.call_count,
)
)
register_assertion(assertion)
def add_call_order_assertion(self) -> None:
global _call_order_assertion_registered, _received_ordered_calls, _expected_ordered_calls
if not _call_order_assertion_registered:
def assertion() -> None:
if _received_ordered_calls != _expected_ordered_calls:
raise AssertionError(
(
"calls did not match assertion.\n"
"\n"
"These calls were expected to have happened in order:\n"
"\n"
"{}\n"
"\n"
"but these calls were made:\n"
"\n"
"{}"
).format(
"\n".join(
(
" {}, {} with {}".format(
_format_target(target),
repr(method),
runner._args_message().rstrip(),
)
for target, method, runner in _expected_ordered_calls
)
),
"\n".join(
(
" {}, {} with {}".format(
_format_target(target),
repr(method),
runner._args_message().rstrip(),
)
for target, method, runner in _received_ordered_calls
)
),
)
)
register_assertion(assertion)
_call_order_assertion_registered = True
_expected_ordered_calls.append((self.target, self.method, self))
self._has_order_assertion = True
class _Runner(_BaseRunner):
    """Synchronous runner; the base run() only records the call."""

    def run(self, *args: Any, **kwargs: Any) -> None:
        # Subclasses call this first, then produce the behavior's result.
        super(_Runner, self).register_call(*args, **kwargs)
class _AsyncRunner(_BaseRunner):
    """Asynchronous runner; the base run() coroutine only records the call."""

    async def run(self, *args: Any, **kwargs: Any) -> None:
        # Subclasses await this first, then produce the behavior's result.
        super(_AsyncRunner, self).register_call(*args, **kwargs)
class _ReturnValueRunner(_Runner):
    """Runner that returns the same preconfigured value on every call."""

    def __init__(
        self,
        target: Any,
        method: str,
        original_callable: Union[Callable, Mock],
        value: Optional[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        # A coroutine object may only be returned when explicitly allowed.
        coroutine_forbidden = not allow_coro
        if coroutine_forbidden and _is_coroutine(value):
            raise CoroutineValueError()
        self.return_value = value

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        result = self.return_value
        return result
class _ReturnValuesRunner(_Runner):
    """Runner returning the next value from a list on each call."""

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        values_list: List[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        if not allow_coro and any(_is_coroutine(rv) for rv in values_list):
            raise CoroutineValueError()
        # Keep a reversed copy so each call can pop() the next value from the
        # end cheaply.
        self.values_list = list(reversed(values_list))

    def run(self, *args: Any, **kwargs: Any) -> Any:
        super().run(*args, **kwargs)
        if not self.values_list:
            raise UndefinedBehaviorForCall("No more values to return!")
        return self.values_list.pop()
class _YieldValuesRunner(_Runner):
    """
    Runner whose calls return the runner itself, used as an iterator over the
    configured values.

    The iteration position lives on the runner, so it is shared by all calls.
    Once the values are consumed, later calls return an exhausted iterator.
    """

    # Values produced through this runner are not type validated.
    TYPE_VALIDATION = False

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        values_list: List[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.values_list = values_list
        self.index = 0
        if not allow_coro and any(_is_coroutine(rv) for rv in values_list):
            raise CoroutineValueError()

    def __iter__(self) -> "_YieldValuesRunner":
        return self

    def __next__(self) -> Any:
        try:
            item = self.values_list[self.index]
        except IndexError:
            raise StopIteration()
        else:
            self.index += 1
            return item

    def run(self, *args: Any, **kwargs: Any) -> "_YieldValuesRunner":  # type: ignore
        super().run(*args, **kwargs)
        return self
class _RaiseRunner(_Runner):
    """Runner that raises one preconfigured exception instance on every call."""

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        exception: BaseException,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.exception = exception

    def run(self, *args: Any, **kwargs: Any) -> None:
        super().run(*args, **kwargs)
        error = self.exception
        raise error
class _ImplementationRunner(_Runner):
    """Runner delegating every call to a user supplied implementation."""

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        new_implementation: Callable,
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.new_implementation = new_implementation
        self._allow_coro = allow_coro

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        result = self.new_implementation(*args, **kwargs)
        # Returning a coroutine object is only legal when explicitly allowed.
        if not self._allow_coro and _is_coroutine(result):
            raise CoroutineValueError()
        return result
class _AsyncImplementationRunner(_AsyncRunner):
    """Async runner awaiting the coroutine returned by a user implementation."""

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        new_implementation: Callable,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.new_implementation = new_implementation

    async def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        await super().run(*args, **kwargs)
        produced = self.new_implementation(*args, **kwargs)
        if _is_coroutine(produced):
            return await produced
        raise NotACoroutine(
            f"Function did not return a coroutine.\n"
            f"{self.new_implementation} must return a coroutine."
        )
class _CallOriginalRunner(_Runner):
    """Runner forwarding every call to the original (unpatched) callable."""

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        original = self.original_callable
        return original(*args, **kwargs)
class _AsyncCallOriginalRunner(_AsyncRunner):
    """Async runner awaiting the original (unpatched) coroutine function."""

    async def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        await super().run(*args, **kwargs)
        original = self.original_callable
        return await original(*args, **kwargs)
##
## Callable Mocks
##
class _CallableMock:
    """
    The callable patched in place of a mocked attribute.

    Holds the runners (behaviors) defined for the attribute. Each call is
    dispatched to the first runner, most recently defined first, whose
    argument filter accepts it. The runner's result is then optionally type
    validated.
    """

    def __init__(
        self,
        target: Any,
        method: str,
        caller_frame_info: Traceback,
        is_async: bool = False,
        callable_returns_coroutine: bool = False,
        # type_validation accepted values:
        # * None: type validation will be enabled except if target is a StrictMock
        #   with disabled type validation
        # * True: type validation will be enabled (regardless of target type)
        # * False: type validation will be disabled
        type_validation: Optional[bool] = None,
    ) -> None:
        self.target = target
        self.method = method
        self.runners: List[_BaseRunner] = []
        self.is_async = is_async
        self.callable_returns_coroutine = callable_returns_coroutine
        self.caller_frame_info = caller_frame_info
        if type_validation is None and isinstance(target, StrictMock):
            # Unspecified: defer to the StrictMock's own setting.
            self.type_validation = target._type_validation
        else:
            # None means enabled; otherwise honour the explicit flag.
            self.type_validation = type_validation or type_validation is None

    def _get_runner(self, *args: Any, **kwargs: Any) -> Any:
        """Return the first runner accepting these arguments, or None."""
        return next(
            (
                candidate
                for candidate in self.runners
                if candidate.can_accept_args(*args, **kwargs)
            ),
            None,
        )

    def _validate_return_type(self, runner: _BaseRunner, value: Any) -> None:
        """Type check ``value`` when both this mock and the runner allow it."""
        if not (self.type_validation and runner.TYPE_VALIDATION):
            return
        if runner.original_callable is not None:
            _validate_return_type(
                runner.original_callable,
                value,
                self.caller_frame_info,
                self.callable_returns_coroutine,
            )
        elif isinstance(runner.target, StrictMock):
            _validate_return_type(
                getattr(runner.target, runner.method), value, self.caller_frame_info
            )

    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        runner = self._get_runner(*args, **kwargs)
        if not runner:
            message = (
                "{}, {}:\n"
                " Received call:\n"
                "{}"
                " But no behavior was defined for it."
            ).format(
                _format_target(self.target),
                repr(self.method),
                _format_args(2, *args, **kwargs),
            )
            registered = self._registered_calls
            if registered:
                # Some behaviors exist, just none for these arguments.
                message += "\n These are the registered calls:\n" "{}".format(
                    "".join(
                        _format_args(2, *reg_args, **reg_kwargs)
                        for reg_args, reg_kwargs in registered
                    )
                )
                raise UnexpectedCallArguments(message)
            raise UndefinedBehaviorForCall(message)

        if not self.is_async:
            value = runner.run(*args, **kwargs)
            self._validate_return_type(runner, value)
            return value

        # Async mocks return a coroutine; the runner only executes once that
        # coroutine is awaited.
        runner_is_async = isinstance(runner, _AsyncRunner)

        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            if runner_is_async:
                value = await runner.run(*args, **kwargs)
            else:
                value = runner.run(*args, **kwargs)
            self._validate_return_type(runner, value)
            return value

        return async_wrapper(*args, **kwargs)

    @property
    def _registered_calls(self) -> Any:
        """(args, kwargs) filters of all runners that have one."""
        return [runner.accepted_args for runner in self.runners if runner.accepted_args]
##
## Support
##
class _MockCallableDSL:
    """
    Fluent API returned by mock_callable() to define behaviors and call
    assertions for one callable attribute of a target.

    Each instance defines at most one behavior ("runner"). All instances for
    the same (target, attribute) pair share a single _CallableMock, kept in
    CALLABLE_MOCKS until unpatch_all_callable_mocks() runs.
    """

    # Maps (id(target), method) to the callable mock patched in for it.
    CALLABLE_MOCKS: Dict[
        Union[int, Tuple[int, str]], Union[Callable[[Type[object]], Any]]
    ] = {}
    # DSL name used in user-facing messages.
    _NAME: str = "mock_callable"

    def _validate_patch(
        self,
        name: str = "mock_callable",
        other_name: str = "mock_async_callable",
        coroutine_function: bool = False,
        callable_returns_coroutine: bool = False,
    ) -> None:
        """
        Reject patches this DSL cannot or should not apply.

        :param name: DSL name used in error messages.
        :param other_name: alternative DSL suggested in error messages.
        :param coroutine_function: True when the DSL mocks coroutine functions
            (mock_async_callable()) instead of regular callables.
        :param callable_returns_coroutine: when mocking coroutine functions,
            also accept attributes that are not coroutine functions.
        :raises ValueError: for any unsupported target/attribute combination.
        """
        if self._method == "__new__":
            raise ValueError(
                f"Mocking __new__ is not allowed with {name}(), please use "
                "mock_constructor()."
            )
        _bail_if_private(self._method, self.allow_private)
        if isinstance(self._target, StrictMock):
            # Only the template is inspected; attributes missing from it (or
            # non-callable ones) are not validated here.
            template_value = getattr(self._target._template, self._method, None)
            if template_value and callable(template_value):
                if not coroutine_function and asyncio.iscoroutinefunction(
                    template_value
                ):
                    raise ValueError(
                        f"{name}() can not be used with coroutine functions.\n"
                        f"The attribute '{self._method}' of the template class "
                        f"of {self._target} is a coroutine function. You can "
                        f"use {other_name}() instead."
                    )
                if coroutine_function and not (
                    _is_coroutinefunction(template_value) or callable_returns_coroutine
                ):
                    raise ValueError(
                        f"{name}() can not be used with non coroutine "
                        "functions.\n"
                        f"The attribute '{self._method}' of the template class "
                        f"of {self._target} is not a coroutine function. You "
                        f"can use {other_name}() instead."
                    )
        else:
            if inspect.isclass(self._target) and _is_instance_method(
                self._target, self._method
            ):
                raise ValueError(
                    "Patching an instance method at the class is not supported: "
                    "bugs are easy to introduce, as patch is not scoped for an "
                    "instance, which can potentially even break class behavior; "
                    "assertions on calls are ambiguous (for every instance or one "
                    "global assertion?)."
                )
            original_callable = getattr(self._target, self._method)
            if not callable(original_callable):
                raise ValueError(
                    f"{name}() can only be used with callable attributes and "
                    f"{repr(original_callable)} is not."
                )
            if inspect.isclass(original_callable):
                raise ValueError(
                    f"{name}() can not be used with with classes: "
                    f"{repr(original_callable)}. Perhaps you want to use "
                    "mock_constructor() instead."
                )
            if not coroutine_function and asyncio.iscoroutinefunction(
                original_callable
            ):
                raise ValueError(
                    f"{name}() can not be used with coroutine functions.\n"
                    f"{original_callable} is a coroutine function. You can use "
                    f"{other_name}() instead."
                )
            if coroutine_function and not (
                _is_coroutinefunction(original_callable) or callable_returns_coroutine
            ):
                raise ValueError(
                    f"{name}() can not be used with non coroutine functions.\n"
                    f"{original_callable} is not a coroutine function. You can "
                    f"use {other_name}() instead."
                )

    def _patch(
        self, new_value: Union[Callable, _CallableMock]
    ) -> Union[Tuple[Callable, Callable], Tuple[Mock, Callable], Tuple[None, Callable]]:
        """
        Validate, then install ``new_value`` as the target's attribute.

        :return: (original callable or None for StrictMock targets, unpatcher).
        """
        self._validate_patch()

        if isinstance(self._target, StrictMock):
            original_callable = None
        else:
            original_callable = getattr(self._target, self._method)

            # NOTE(review): signature/type wrapping is applied to
            # non-StrictMock targets only; presumably StrictMock validates
            # attributes itself when they are set - confirm.
            new_value = _wrap_signature_and_type_validation(
                new_value,
                self._target,
                self._method,
                self.type_validation or self.type_validation is None,
            )

        # Remember whether the attribute lived directly on the target so the
        # unpatcher can restore it.
        restore = self._method in self._target.__dict__
        restore_value = self._target.__dict__.get(self._method, None)

        if inspect.isclass(self._target):
            # Prevent the mock from being bound as a method when accessed
            # through the class.
            new_value = staticmethod(new_value)  # type: ignore

        unpatcher = _patch(
            self._target, self._method, new_value, restore, restore_value
        )

        return original_callable, unpatcher

    def _get_callable_mock(self) -> _CallableMock:
        """Create the _CallableMock patched in for the attribute."""
        return _CallableMock(
            self._original_target,
            self._method,
            self.caller_frame_info,
            type_validation=self.type_validation,
        )

    def __init__(
        self,
        target: Any,
        method: str,
        caller_frame_info: Traceback,
        callable_mock: Union[Callable[[Type[object]], Any], _CallableMock, None] = None,
        original_callable: Optional[Callable] = None,
        allow_private: bool = False,
        type_validation: Optional[bool] = None,
    ) -> None:
        if not _is_setup():
            raise RuntimeError(
                "TestSlide was not correctly setup before usage!\n"
                "This error happens when mock_callable, mock_async_callable or "
                "mock_constructor are attempted to be used without correct "
                "test framework integration, meaning unpatching and "
                "assertions will not work as expected.\n"
                "A common scenario for this is when testslide.TestCase is "
                "subclassed with setUp() overridden but super().setUp() was not "
                "called."
            )
        self._original_target = target
        self._method = method
        self._runner: Optional[_BaseRunner] = None
        self._next_runner_accepted_args: Any = None
        self.allow_private = allow_private
        self.type_validation = type_validation
        self.caller_frame_info = caller_frame_info
        self._allow_coro = False
        self._accept_partial_call = False
        if isinstance(target, str):
            self._target = testslide._importer(target)
        else:
            self._target = target

        target_method_id = (id(self._target), method)

        if target_method_id not in self.CALLABLE_MOCKS:
            # First DSL for this (target, attribute) during the test: create
            # (or adopt) the shared callable mock and schedule its removal.
            if not callable_mock:
                patch = True
                callable_mock = self._get_callable_mock()
            else:
                patch = False
            self.CALLABLE_MOCKS[target_method_id] = callable_mock
            self._callable_mock = callable_mock

            def del_callable_mock() -> None:
                del self.CALLABLE_MOCKS[target_method_id]

            _unpatchers.append(del_callable_mock)

            if patch:
                original_callable, unpatcher = self._patch(callable_mock)
                _unpatchers.append(unpatcher)
            self._original_callable = original_callable
            callable_mock.original_callable = original_callable  # type: ignore
        else:
            # Reuse the callable mock already patched in for this attribute.
            self._callable_mock = self.CALLABLE_MOCKS[target_method_id]
            self._original_callable = self._callable_mock.original_callable  # type: ignore

    def _add_runner(self, runner: _BaseRunner) -> None:
        """
        Attach ``runner`` as this chain's single behavior.

        Any pending for_call()/for_partial_call() filter is applied to it. It
        is inserted first so it takes precedence over earlier behaviors.
        """
        if self._runner:
            raise ValueError(
                "Can't define more than one behavior using the same "
                "self.mock_callable() chain. Please call self.mock_callable() again "
                "one time for each new behavior."
            )
        if self._next_runner_accepted_args:
            args, kwargs = self._next_runner_accepted_args
            self._next_runner_accepted_args = None
            runner.add_accepted_args(self._accept_partial_call, *args, **kwargs)
            self._accept_partial_call = False
        self._runner = runner
        self._callable_mock.runners.insert(0, runner)  # type: ignore

    def _assert_runner(self) -> None:
        """Require a defined behavior that has not received any call yet."""
        if not self._runner:
            raise ValueError(
                "You must first define a behavior. Eg: "
                "self.mock_callable(target, 'func')"
                ".to_return_value(value)"
                ".and_assert_called_exactly(times)"
            )
        if self._runner._call_count > 0:
            raise ValueError(
                f"No extra configuration is allowed after {self._NAME} "
                f"receives its first call, it received {self._runner._call_count} "
                f"call{'s' if self._runner._call_count > 1 else ''} already. "
                "You should instead define it all at once, "
                f"eg: self.{self._NAME}(target, 'func')"
                ".to_return_value(value).and_assert_called_once()"
            )

    ##
    ## Arguments
    ##

    def for_call(
        self, *args: Any, **kwargs: Any
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Filter for only calls like this: the behavior only serves calls whose
        arguments are exactly equal to the given ones.
        """
        if self._runner:
            self._runner.add_accepted_args(False, *args, **kwargs)
        else:
            self._next_runner_accepted_args = (args, kwargs)
        return self

    def for_partial_call(
        self, *args: Any, **kwargs: Any
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Filter for calls that include the given arguments. Each given
        positional argument must equal one of the received positional
        arguments. Each given keyword argument must be received with an
        equal value.
        """
        if self._runner:
            self._runner.add_accepted_args(True, *args, **kwargs)
        else:
            self._accept_partial_call = True
            self._next_runner_accepted_args = (args, kwargs)
        return self

    ##
    ## Behavior
    ##

    def to_return_value(
        self, value: Any
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Always return given value.
        """
        self._add_runner(
            _ReturnValueRunner(
                self._original_target,
                self._method,
                self._original_callable,  # type: ignore
                value,
                self._allow_coro,
            )
        )
        return self

    def to_return_values(
        self, values_list: List[Any]
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        For each call, return each value from given list in order.
        Once the list is exhausted, further calls raise
        UndefinedBehaviorForCall.

        :raises ValueError: if ``values_list`` is not a list.
        """
        if not isinstance(values_list, list):
            raise ValueError("{} is not a list".format(values_list))
        self._add_runner(
            _ReturnValuesRunner(
                self._original_target,
                self._method,
                self._original_callable,  # type: ignore
                values_list,
                self._allow_coro,
            )
        )
        return self

    def to_yield_values(
        self, values_list: List[Any]
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Callable will return an iterator that yields each value from the
        given list.

        :raises ValueError: if ``values_list`` is not a list.
        """
        if not isinstance(values_list, list):
            raise ValueError("{} is not a list".format(values_list))
        self._add_runner(
            _YieldValuesRunner(
                self._original_target,
                self._method,
                self._original_callable,  # type: ignore
                values_list,
                self._allow_coro,
            )
        )
        return self

    def to_raise(
        self, ex: Union[Type[BaseException], BaseException]
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Raises given exception class or exception instance.

        An exception instance is raised as-is. Otherwise ``ex`` must be a
        callable (typically an exception class) returning a BaseException
        instance when called without arguments. It is called exactly once,
        here, and that same instance is raised on every call.

        :raises ValueError: if ``ex`` is neither a BaseException instance nor
            a callable producing one.
        """
        if isinstance(ex, BaseException):
            exception = ex
        else:
            # Call ``ex`` only once, so the raised instance is the one that
            # was validated. Report non-callables with the documented
            # ValueError instead of letting ``ex()`` raise TypeError.
            exception = ex() if callable(ex) else None
            if not isinstance(exception, BaseException):
                raise ValueError(
                    "{} is not subclass or instance of BaseException".format(ex)
                )
        self._add_runner(
            _RaiseRunner(
                self._original_target, self._method, self._original_callable, exception  # type: ignore
            )
        )
        return self

    def with_implementation(
        self, func: Callable
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Replace callable by given function.

        :raises ValueError: if ``func`` is not callable.
        """
        if not callable(func):
            raise ValueError("{} must be callable.".format(func))
        self._add_runner(
            _ImplementationRunner(
                self._original_target,
                self._method,
                self._original_callable,  # type: ignore
                func,
                self._allow_coro,
            )
        )
        return self

    def with_wrapper(
        self, func: Callable
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Replace callable with given wrapper function, that will be called as:
          func(original_func, *args, **kwargs)
        receiving the original function as the first argument as well as any given
        arguments.

        :raises ValueError: if ``func`` is not callable or there is no
            original callable to wrap.
        """
        if not callable(func):
            raise ValueError("{} must be callable.".format(func))

        if not self._original_callable:
            raise ValueError("Can not wrap original callable that does not exist.")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return func(self._original_callable, *args, **kwargs)

        self._add_runner(
            _ImplementationRunner(
                self._original_target,
                self._method,
                self._original_callable,
                wrapper,
                self._allow_coro,
            )
        )
        return self

    def to_call_original(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Calls the original callable implementation, instead of mocking it. This is
        useful, for example, if you want to call the original implementation by
        default but mock the result for specific calls.

        :raises ValueError: if there is no original callable.
        """
        if not self._original_callable:
            raise ValueError("Can not call original callable that does not exist.")
        self._add_runner(
            _CallOriginalRunner(
                self._original_target, self._method, self._original_callable
            )
        )
        return self

    ##
    ## Call Assertions
    ##

    def and_assert_called_exactly(
        self, count: int
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that there were exactly the given number of calls.
        If assertion is for 0 calls and no behavior was defined yet, a behavior
        raising UnexpectedCallReceived is installed, so any received call raises
        it; the assertion also fails with an AssertionError.
        """
        if count:
            self._assert_runner()
        else:
            if not self._runner:
                self.to_raise(
                    UnexpectedCallReceived(
                        ("{}, {}: Expected not to be called!").format(
                            _format_target(self._target), repr(self._method)
                        )
                    )
                )
        self._runner.add_exact_calls_assertion(count)  # type: ignore
        return self

    def and_assert_not_called(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for and_assert_called_exactly(0)
        """
        return self.and_assert_called_exactly(0)

    def and_assert_called_once(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for and_assert_called_exactly(1)
        """
        return self.and_assert_called_exactly(1)

    def and_assert_called_twice(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for and_assert_called_exactly(2)
        """
        return self.and_assert_called_exactly(2)

    def and_assert_called_at_least(
        self, count: int
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that there were at least the given number of calls.

        :raises ValueError: if ``count`` is smaller than 1.
        """
        if count < 1:
            raise ValueError("times must be >= 1")
        self._assert_runner()
        self._runner.add_at_least_calls_assertion(count)  # type: ignore
        return self

    def and_assert_called_at_most(
        self, count: int
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that there were at most the given number of calls.

        :raises ValueError: if ``count`` is smaller than 1.
        """
        if count < 1:
            raise ValueError("times must be >= 1")
        self._assert_runner()
        self._runner.add_at_most_calls_assertion(count)  # type: ignore
        return self

    def and_assert_called(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for self.and_assert_called_at_least(1).
        """
        return self.and_assert_called_at_least(1)

    def and_assert_called_ordered(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that multiple calls, potentially to different mock_callable()
        targets, happened in the order defined.
        """
        self._assert_runner()
        self._runner.add_call_order_assertion()  # type: ignore
        return self
class _MockAsyncCallableDSL(_MockCallableDSL):
    """
    mock_callable() DSL variant for coroutine functions.

    The patched attribute returns a coroutine. Behaviors may return coroutine
    objects, and implementations and wrappers must return coroutines.
    """

    _NAME: str = "mock_async_callable"

    def __init__(
        self,
        target: Union[str, type],
        method: str,
        caller_frame_info: Traceback,
        callable_returns_coroutine: bool,
        allow_private: bool = False,
        type_validation: bool = True,
    ) -> None:
        # Must be set before the base initializer runs: it may create the
        # callable mock and validate the patch, and both read this flag.
        self._callable_returns_coroutine = callable_returns_coroutine
        super().__init__(
            target,
            method,
            caller_frame_info,
            allow_private=allow_private,
            type_validation=type_validation,
        )
        # The base initializer resets this to False.
        self._allow_coro = True

    def _validate_patch(self) -> None:  # type: ignore
        """Validate the patch in coroutine-function mode."""
        return super()._validate_patch(
            name=self._NAME,
            other_name="mock_callable",
            coroutine_function=True,
            callable_returns_coroutine=self._callable_returns_coroutine,
        )

    def _get_callable_mock(self) -> _CallableMock:
        """Create an async _CallableMock for the attribute."""
        return _CallableMock(
            self._original_target,
            self._method,
            self.caller_frame_info,
            is_async=True,
            callable_returns_coroutine=self._callable_returns_coroutine,
            type_validation=self.type_validation,
        )

    def with_implementation(self, func: Callable) -> "_MockAsyncCallableDSL":
        """
        Replace callable by given async function.

        :raises ValueError: if ``func`` is not callable.
        """
        if not callable(func):
            raise ValueError("{} must be callable.".format(func))
        runner = _AsyncImplementationRunner(
            self._original_target, self._method, self._original_callable, func  # type: ignore
        )
        self._add_runner(runner)
        return self

    def with_wrapper(self, func: Callable) -> "_MockAsyncCallableDSL":
        """
        Replace callable with given wrapper async function, that will be called as:
          await func(original_async_func, *args, **kwargs)
        receiving the original function as the first argument as well as any given
        arguments.

        :raises ValueError: if ``func`` is not callable or there is no
            original callable to wrap.
        """
        if not callable(func):
            raise ValueError("{} must be callable.".format(func))
        if not self._original_callable:
            raise ValueError("Can not wrap original callable that does not exist.")

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            produced = func(self._original_callable, *args, **kwargs)
            if _is_coroutine(produced):
                return await produced
            raise NotACoroutine(
                f"Function did not return a coroutine.\n"
                f"{func} must return a coroutine."
            )

        runner = _AsyncImplementationRunner(
            self._original_target, self._method, self._original_callable, wrapper
        )
        self._add_runner(runner)
        return self

    def to_call_original(self) -> "_MockAsyncCallableDSL":
        """
        Calls the original callable implementation, instead of mocking it. This is
        useful, for example, if you want to call the original implementation by
        default but mock the result for specific calls.

        :raises ValueError: if there is no original callable.
        """
        if not self._original_callable:
            raise ValueError("Can not call original callable that does not exist.")
        runner = _AsyncCallOriginalRunner(
            self._original_target, self._method, self._original_callable
        )
        self._add_runner(runner)
        return self