# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import functools
import inspect
from inspect import Traceback
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)
from unittest.mock import Mock

import testslide
from testslide.lib import _validate_return_type, _wrap_signature_and_type_validation
from testslide.strict_mock import StrictMock

from .lib import CoroutineValueError, _bail_if_private, _is_a_builtin
from .patch import _is_instance_method, _patch

if TYPE_CHECKING:
    from testslide.matchers import RegexMatches  # noqa: F401
    from testslide.mock_constructor import _MockConstructorDSL  # noqa: F401


def mock_callable(
    target: Any,
    method: str,
    allow_private: bool = False,
    # type_validation accepted values:
    #  * None:  type validation will be enabled except if target is a StrictMock
    #           with disabled type validation
    #  * True:  type validation will be enabled (regardless of target type)
    #  * False:  type validation will be disabled
    type_validation: Optional[bool] = None,
) -> "_MockCallableDSL":
    """
    Start the definition of a mock for the callable attribute ``method`` of
    ``target``.

    The frame information of the code calling this function is captured and
    handed, together with every argument, to the returned ``_MockCallableDSL``.

    :param target: object holding the attribute to mock (the DSL also accepts
        a string, which it resolves through ``testslide._importer``).
    :param method: name of the attribute of ``target`` to mock.
    :param allow_private: forwarded to the DSL, which uses it when validating
        the attribute name.
    :param type_validation: see the accepted values listed in the signature.
    :return: the DSL object used to define behaviors and assertions.
    """
    # The direct caller is exactly one frame up (f_back); this mirrors
    # mock_async_callable(). Going two frames up would describe the caller's
    # caller and break when the caller has no parent frame.
    caller_frame = inspect.currentframe().f_back  # type: ignore
    # loading the context ends up reading files from disk and that might block
    # the event loop, so we don't do it.
    caller_frame_info = inspect.getframeinfo(caller_frame, context=0)  # type: ignore
    return _MockCallableDSL(
        target,
        method,
        caller_frame_info,
        allow_private=allow_private,
        type_validation=type_validation,
    )


def mock_async_callable(
    target: Union[type, str],
    method: str,
    callable_returns_coroutine: bool = False,
    allow_private: bool = False,
    type_validation: bool = True,
) -> "_MockAsyncCallableDSL":
    """
    Start the definition of a mock for the async callable attribute ``method``
    of ``target``.

    Captures the frame information of the calling code and forwards it, along
    with every argument, to a new ``_MockAsyncCallableDSL``.
    """
    # context=0: loading the source context reads files from disk, which
    # might block the event loop, so it is skipped.
    frame_info = inspect.getframeinfo(
        inspect.currentframe().f_back, context=0  # type: ignore
    )
    dsl = _MockAsyncCallableDSL(
        target,
        method,
        frame_info,
        callable_returns_coroutine,
        allow_private,
        type_validation,
    )
    return dsl


# Cleanup callables for the currently active mocks.
# unpatch_all_callable_mocks() calls each one and then empties this list.
_unpatchers: List[Callable] = []  # noqa T484


def _default_register_assertion(assertion: Callable) -> None:
    """
    This method must be redefined by the test framework using mock_callable().
    It will be called when a new assertion is defined, passing a callable as an
    argument that evaluates that assertion. Every defined assertion during a test
    must be called after the test code ends, and before the test finishes.
    """
    raise NotImplementedError("This method must be redefined by the test framework")


# Hook called with each new assertion callable. It points at the raising
# default until replaced; _is_setup() reports whether it has been replaced.
register_assertion = _default_register_assertion
# Set to True once the shared call-order assertion has been registered
# (see _BaseRunner.add_call_order_assertion).
_call_order_assertion_registered: bool = False
# (target, method, runner) entries appended when an order-asserted runner
# receives a call.
_received_ordered_calls: List[Tuple[Any, str, "_BaseRunner"]] = []
# (target, method, runner) entries appended when an order assertion is added.
_expected_ordered_calls: List[Tuple[Any, str, "_BaseRunner"]] = []


def unpatch_all_callable_mocks() -> None:
    """
    Remove every active mock_callable() patch and reset module state.

    Must be called unconditionally after every test. The assertion hook is
    restored to its default, call-order bookkeeping is emptied and each
    registered unpatcher is invoked. Exceptions raised by unpatchers are
    collected so that all of them still run, then reported together in a
    single RuntimeError.
    """
    global register_assertion, _call_order_assertion_registered

    register_assertion = _default_register_assertion
    _call_order_assertion_registered = False
    _received_ordered_calls.clear()
    _expected_ordered_calls.clear()

    failures = []
    for undo in _unpatchers:
        try:
            undo()
        except Exception as error:
            failures.append(error)
    _unpatchers.clear()

    if not failures:
        return
    raise RuntimeError("Exceptions raised when unpatching: {}".format(failures))


def _is_setup() -> bool:
    """
    Tell whether the assertion registration hook has been replaced, that is,
    whether ``register_assertion`` no longer is the raising default.
    """
    still_default = register_assertion is _default_register_assertion
    return not still_default


def _format_target(target: Union[str, type]) -> str:
    if hasattr(target, "__repr__"):
        return repr(target)
    else:
        return "{}.{} instance with id {}".format(
            target.__module__, type(target).__name__, id(target)
        )


def _format_args(indent: int, *args: Any, **kwargs: Any) -> str:
    indentation = "  " * indent
    s = ""
    if args:
        s += ("{}{}\n").format(indentation, args)
    if kwargs:
        s += indentation + "{"
        if kwargs:
            s += "\n"
            for k in sorted(kwargs.keys()):
                s += "{}  {}={},\n".format(indentation, k, repr(kwargs[k]))
            s += "{}".format(indentation)
        s += "}\n"
    return s


def _is_coroutine(obj: Any) -> bool:

    return inspect.iscoroutine(obj) or isinstance(obj, asyncio.coroutines.CoroWrapper)  # type: ignore


def _is_coroutinefunction(func: Any) -> bool:
    # We use asyncio.iscoroutinefunction over inspect because the next Cython version
    # will return True from the asyncio variant over inspect which will return False
    # FIXME We can not reliably introspect coroutine functions
    # for builtins: https://bugs.python.org/issue38225
    return asyncio.iscoroutinefunction(func) or _is_a_builtin(func)


##
## Exceptions
##


class UndefinedBehaviorForCall(BaseException):
    """
    Signals that a mock was called but has no behavior able to handle the
    call.

    Derives from BaseException (not Exception) so that tested code does not
    catch it by accident.
    """


class UnexpectedCallReceived(BaseException):
    """
    Signals that a mock received a call it is configured not to accept.

    Derives from BaseException (not Exception) so that tested code does not
    catch it by accident.
    """


class UnexpectedCallArguments(BaseException):
    """
    Signals that a mock received a call whose arguments were not expected.

    Derives from BaseException (not Exception) so that tested code does not
    catch it by accident.
    """


class NotACoroutine(BaseException):
    """
    Signals that a mock requiring a coroutine was given something else.

    Derives from BaseException (not Exception) so that tested code does not
    catch it by accident.
    """


##
## Runners
##


class _BaseRunner:
    TYPE_VALIDATION = True

    def __init__(
        self, target: Any, method: str, original_callable: Union[Callable, Mock]
    ) -> None:
        self.target = target
        self.method = method
        self.original_callable = original_callable
        self.accepted_args: Optional[Tuple[Any, Any]] = None

        self._call_count: int = 0
        self._max_calls: Optional[int] = None
        self._has_order_assertion = False
        self._accept_partial_call = False

    def register_call(self, *args: Any, **kwargs: Any) -> None:
        global _received_ordered_calls

        if self._has_order_assertion:
            _received_ordered_calls.append((self.target, self.method, self))

        self.inc_call_count()

    @property
    def call_count(self) -> int:
        return self._call_count

    @property
    def max_calls(self) -> Optional[int]:
        return self._max_calls

    def _set_max_calls(self, times: int) -> None:
        if not self._max_calls or times < self._max_calls:
            self._max_calls = times

    def inc_call_count(self) -> None:
        self._call_count += 1
        if self.max_calls and self._call_count > self.max_calls:
            raise UnexpectedCallReceived(
                (
                    "Unexpected call received.\n"
                    "{}, {}:\n"
                    "  expected to receive at most {} calls with {}"
                    "  but received an extra call."
                ).format(
                    _format_target(self.target),
                    repr(self.method),
                    self.max_calls,
                    self._args_message(),
                )
            )

    def add_accepted_args(
        self,
        _accept_partial_call: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.accepted_args = (args, kwargs)
        self._accept_partial_call = _accept_partial_call

    def can_accept_args(self, *args: Any, **kwargs: Any) -> bool:
        if self.accepted_args:
            if self._accept_partial_call:
                args_match = all(
                    any(elem == arg for arg in args) for elem in self.accepted_args[0]
                )
                kwargs_match = all(
                    elem in kwargs.keys()
                    and kwargs[elem] == self.accepted_args[1][elem]
                    for elem in self.accepted_args[1].keys()
                )
                return args_match and kwargs_match
            else:
                return self.accepted_args == (args, kwargs)
        else:
            return True

    def _args_message(self) -> str:
        if self.accepted_args:
            return "arguments:\n{}".format(
                _format_args(2, *self.accepted_args[0], **self.accepted_args[1])
            )
        else:
            return "any arguments "

    def add_exact_calls_assertion(self, times: int) -> None:
        self._set_max_calls(times)

        def assertion() -> None:
            if times != self.call_count:
                raise AssertionError(
                    (
                        "calls did not match assertion.\n"
                        "{}, {}:\n"
                        "  expected: called exactly {} time(s) with {}"
                        "  received: {} call(s)"
                    ).format(
                        _format_target(self.target),
                        repr(self.method),
                        times,
                        self._args_message(),
                        self.call_count,
                    )
                )

        register_assertion(assertion)

    def add_at_least_calls_assertion(self, times: int) -> None:
        def assertion() -> None:
            if self.call_count < times:
                raise AssertionError(
                    (
                        "calls did not match assertion.\n"
                        "{}, {}:\n"
                        "  expected: called at least {} time(s) with {}"
                        "  received: {} call(s)"
                    ).format(
                        _format_target(self.target),
                        repr(self.method),
                        times,
                        self._args_message(),
                        self.call_count,
                    )
                )

        register_assertion(assertion)

    def add_at_most_calls_assertion(self, times: int) -> None:
        self._set_max_calls(times)

        def assertion() -> None:
            if not self.call_count or self.call_count > times:
                raise AssertionError(
                    (
                        "calls did not match assertion.\n"
                        "{}, {}:\n"
                        "  expected: called at most {} time(s) with {}"
                        "  received: {} call(s)"
                    ).format(
                        _format_target(self.target),
                        repr(self.method),
                        times,
                        self._args_message(),
                        self.call_count,
                    )
                )

        register_assertion(assertion)

    def add_call_order_assertion(self) -> None:
        global _call_order_assertion_registered, _received_ordered_calls, _expected_ordered_calls

        if not _call_order_assertion_registered:

            def assertion() -> None:
                if _received_ordered_calls != _expected_ordered_calls:
                    raise AssertionError(
                        (
                            "calls did not match assertion.\n"
                            "\n"
                            "These calls were expected to have happened in order:\n"
                            "\n"
                            "{}\n"
                            "\n"
                            "but these calls were made:\n"
                            "\n"
                            "{}"
                        ).format(
                            "\n".join(
                                (
                                    "  {}, {} with {}".format(
                                        _format_target(target),
                                        repr(method),
                                        runner._args_message().rstrip(),
                                    )
                                    for target, method, runner in _expected_ordered_calls
                                )
                            ),
                            "\n".join(
                                (
                                    "  {}, {} with {}".format(
                                        _format_target(target),
                                        repr(method),
                                        runner._args_message().rstrip(),
                                    )
                                    for target, method, runner in _received_ordered_calls
                                )
                            ),
                        )
                    )

            register_assertion(assertion)
            _call_order_assertion_registered = True

        _expected_ordered_calls.append((self.target, self.method, self))
        self._has_order_assertion = True


class _Runner(_BaseRunner):
    """Synchronous runner whose run() only records the received call."""

    def run(self, *args: Any, **kwargs: Any) -> None:
        super(_Runner, self).register_call(*args, **kwargs)


class _AsyncRunner(_BaseRunner):
    """Runner with a coroutine run() that only records the received call."""

    async def run(self, *args: Any, **kwargs: Any) -> None:
        super(_AsyncRunner, self).register_call(*args, **kwargs)


class _ReturnValueRunner(_Runner):
    """Runner that returns the same configured value on every call."""

    def __init__(
        self,
        target: Any,
        method: str,
        original_callable: Union[Callable, Mock],
        value: Optional[Any],
        allow_coro: bool = False,
    ) -> None:
        super(_ReturnValueRunner, self).__init__(target, method, original_callable)
        # Coroutine objects are rejected as return values unless allowed.
        coroutine_forbidden = not allow_coro
        if coroutine_forbidden and _is_coroutine(value):
            raise CoroutineValueError()
        self.return_value = value

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super(_ReturnValueRunner, self).run(*args, **kwargs)
        result = self.return_value
        return result


class _ReturnValuesRunner(_Runner):
    """
    Runner that returns the configured values one per call, in order, and
    raises UndefinedBehaviorForCall once they are exhausted.
    """

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        values_list: List[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        if not allow_coro:
            if any(_is_coroutine(candidate) for candidate in values_list):
                raise CoroutineValueError()
        # Kept reversed so the next value can be popped from the end.
        self.values_list = list(reversed(values_list))

    def run(self, *args: Any, **kwargs: Any) -> Any:
        super().run(*args, **kwargs)
        if not self.values_list:
            raise UndefinedBehaviorForCall("No more values to return!")
        return self.values_list.pop()


class _YieldValuesRunner(_Runner):
    """
    Runner that acts as an iterator over the configured values; run() returns
    the runner itself. Return type validation is disabled for it.
    """

    TYPE_VALIDATION = False

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        values_list: List[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.values_list = values_list
        self.index = 0
        if not allow_coro:
            if any(_is_coroutine(candidate) for candidate in values_list):
                raise CoroutineValueError()

    def __iter__(self) -> "_YieldValuesRunner":
        return self

    def __next__(self) -> Any:
        try:
            current = self.values_list[self.index]
        except IndexError:
            # Past the last value: iteration is over.
            raise StopIteration()
        else:
            self.index += 1
            return current

    def run(self, *args: Any, **kwargs: Any) -> "_YieldValuesRunner":  # type: ignore
        super().run(*args, **kwargs)
        return self


class _RaiseRunner(_Runner):
    """Runner that records the call and then raises the configured exception."""

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        exception: BaseException,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.exception = exception

    def run(self, *args: Any, **kwargs: Any) -> None:
        super().run(*args, **kwargs)
        error = self.exception
        raise error


class _ImplementationRunner(_Runner):
    """
    Runner that delegates every call to a replacement implementation and
    returns whatever it returns.
    """

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        new_implementation: Callable,
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        self._allow_coro = allow_coro
        self.new_implementation = new_implementation

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        result = self.new_implementation(*args, **kwargs)
        # A coroutine result is only acceptable when explicitly allowed.
        if self._allow_coro:
            return result
        if _is_coroutine(result):
            raise CoroutineValueError()
        return result


class _AsyncImplementationRunner(_AsyncRunner):
    """
    Async runner that calls a replacement implementation, requires it to
    return a coroutine and returns the awaited result.
    """

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        new_implementation: Callable,
    ) -> None:
        super(_AsyncImplementationRunner, self).__init__(
            target, method, original_callable
        )
        self.new_implementation = new_implementation

    async def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        await super(_AsyncImplementationRunner, self).run(*args, **kwargs)
        pending = self.new_implementation(*args, **kwargs)
        if _is_coroutine(pending):
            return await pending
        raise NotACoroutine(
            f"Function did not return a coroutine.\n"
            f"{self.new_implementation} must return a coroutine."
        )


class _CallOriginalRunner(_Runner):
    """Runner that records the call and forwards it to the original callable."""

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        original = self.original_callable
        return original(*args, **kwargs)


class _AsyncCallOriginalRunner(_AsyncRunner):
    """
    Async runner that records the call, then awaits the original callable and
    returns its result.
    """

    async def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        await super(_AsyncCallOriginalRunner, self).run(*args, **kwargs)
        result = await self.original_callable(*args, **kwargs)
        return result


##
## Callable Mocks
##


class _CallableMock:
    """
    The callable object that stands in for a mocked attribute.

    Each call is dispatched to the first registered runner accepting the
    received arguments. When none does, UnexpectedCallArguments or
    UndefinedBehaviorForCall is raised.
    """

    def __init__(
        self,
        target: Any,
        method: str,
        caller_frame_info: Traceback,
        is_async: bool = False,
        callable_returns_coroutine: bool = False,
        # type_validation accepted values:
        #  * None:  type validation will be enabled except if target is a StrictMock
        #           with disabled type validation
        #  * True:  type validation will be enabled (regardless of target type)
        #  * False:  type validation will be disabled
        type_validation: Optional[bool] = None,
    ) -> None:
        enabled = type_validation or type_validation is None
        if type_validation is None and isinstance(target, StrictMock):
            # Nothing was requested explicitly for this mock, so the
            # StrictMock's own type validation setting is used.
            enabled = target._type_validation

        self.target = target
        self.method = method
        self.runners: List[_BaseRunner] = []
        self.is_async = is_async
        self.callable_returns_coroutine = callable_returns_coroutine
        self.type_validation = enabled
        self.caller_frame_info = caller_frame_info

    def _get_runner(self, *args: Any, **kwargs: Any) -> Any:
        """Return the first runner accepting the arguments, or None."""
        candidates = (
            runner
            for runner in self.runners
            if runner.can_accept_args(*args, **kwargs)
        )
        return next(candidates, None)

    def _validate_return_type(self, runner: _BaseRunner, value: Any) -> None:
        """
        Validate ``value`` against the return type of the original callable
        or, for StrictMock targets without one, of the mock's attribute.
        Does nothing when validation is disabled for the mock or the runner.
        """
        if not (self.type_validation and runner.TYPE_VALIDATION):
            return
        if runner.original_callable is not None:
            _validate_return_type(
                runner.original_callable,
                value,
                self.caller_frame_info,
                self.callable_returns_coroutine,
            )
            return
        if isinstance(runner.target, StrictMock):
            _validate_return_type(
                getattr(runner.target, runner.method), value, self.caller_frame_info
            )

    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        runner = self._get_runner(*args, **kwargs)

        if not runner:
            ex_msg = (
                "{}, {}:\n"
                "  Received call:\n"
                "{}"
                "  But no behavior was defined for it."
            ).format(
                _format_target(self.target),
                repr(self.method),
                _format_args(2, *args, **kwargs),
            )
            registered = self._registered_calls
            if not registered:
                raise UndefinedBehaviorForCall(ex_msg)
            ex_msg += "\n  These are the registered calls:\n{}".format(
                "".join(
                    _format_args(2, *reg_args, **reg_kwargs)
                    for reg_args, reg_kwargs in registered
                )
            )
            raise UnexpectedCallArguments(ex_msg)

        if not self.is_async:
            value = runner.run(*args, **kwargs)
            self._validate_return_type(runner, value)
            return value

        # Async mocks always hand back a coroutine; async runners additionally
        # need their own run() awaited inside it.
        must_await = isinstance(runner, _AsyncRunner)

        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            value = runner.run(*args, **kwargs)
            if must_await:
                value = await value
            self._validate_return_type(runner, value)
            return value

        return async_wrapper(*args, **kwargs)

    @property
    def _registered_calls(self) -> Any:
        """Accepted (args, kwargs) of every runner restricted to arguments."""
        restricted = []
        for runner in self.runners:
            if runner.accepted_args:
                restricted.append(runner.accepted_args)
        return restricted


##
## Support
##


class _MockCallableDSL:

    # Class-level registry of the active callable mocks. __init__ stores and
    # looks up entries under the key (id(target), method name).
    CALLABLE_MOCKS: Dict[
        Union[int, Tuple[int, str]], Union[Callable[[Type[object]], Any]]
    ] = {}
    # Name of the public entry point, interpolated into error messages.
    _NAME: str = "mock_callable"

    def _validate_patch(
        self,
        name: str = "mock_callable",
        other_name: str = "mock_async_callable",
        coroutine_function: bool = False,
        callable_returns_coroutine: bool = False,
    ) -> None:
        """
        Raise ValueError when the target attribute can not be mocked here.

        Rejected: ``__new__``, names refused by ``_bail_if_private``, instance
        methods patched at a class, non-callable attributes, classes, and a
        mismatch between sync / async mocking and the nature of the callable.
        ``name`` and ``other_name`` only customize the error messages.
        """
        if self._method == "__new__":
            raise ValueError(
                f"Mocking __new__ is not allowed with {name}(), please use "
                "mock_constructor()."
            )
        _bail_if_private(self._method, self.allow_private)

        if isinstance(self._target, StrictMock):
            # For StrictMock the check is made against the template class.
            template_attr = getattr(self._target._template, self._method, None)
            if not (template_attr and callable(template_attr)):
                return
            if coroutine_function:
                if not (
                    _is_coroutinefunction(template_attr)
                    or callable_returns_coroutine
                ):
                    raise ValueError(
                        f"{name}() can not be used with non coroutine "
                        "functions.\n"
                        f"The attribute '{self._method}' of the template class "
                        f"of {self._target} is not a coroutine function. You "
                        f"can use {other_name}() instead."
                    )
            elif asyncio.iscoroutinefunction(template_attr):
                raise ValueError(
                    f"{name}() can not be used with coroutine functions.\n"
                    f"The attribute '{self._method}' of the template class "
                    f"of {self._target} is a coroutine function. You can "
                    f"use {other_name}() instead."
                )
            return

        if inspect.isclass(self._target) and _is_instance_method(
            self._target, self._method
        ):
            raise ValueError(
                "Patching an instance method at the class is not supported: "
                "bugs are easy to introduce, as patch is not scoped for an "
                "instance, which can potentially even break class behavior; "
                "assertions on calls are ambiguous (for every instance or one "
                "global assertion?)."
            )

        current = getattr(self._target, self._method)
        if not callable(current):
            raise ValueError(
                f"{name}() can only be used with callable attributes and "
                f"{repr(current)} is not."
            )
        if inspect.isclass(current):
            raise ValueError(
                f"{name}() can not be used with with classes: "
                f"{repr(current)}. Perhaps you want to use "
                "mock_constructor() instead."
            )
        if coroutine_function:
            if not (_is_coroutinefunction(current) or callable_returns_coroutine):
                raise ValueError(
                    f"{name}() can not be used with non coroutine functions.\n"
                    f"{current} is not a coroutine function. You can "
                    f"use {other_name}() instead."
                )
        elif asyncio.iscoroutinefunction(current):
            raise ValueError(
                f"{name}() can not be used with coroutine functions.\n"
                f"{current} is a coroutine function. You can use "
                f"{other_name}() instead."
            )

    def _patch(
        self, new_value: Union[Callable, _CallableMock]
    ) -> Union[Tuple[Callable, Callable], Tuple[Mock, Callable], Tuple[None, Callable]]:
        """
        Validate the target and replace its attribute with ``new_value``.

        ``new_value`` is wrapped with signature / type validation and, for
        class targets, with staticmethod() before being installed.

        :return: ``(original_callable, unpatcher)``; ``original_callable`` is
            None for StrictMock targets.
        """
        self._validate_patch()

        target = self._target
        attr_name = self._method

        original_callable = (
            None if isinstance(target, StrictMock) else getattr(target, attr_name)
        )

        validate_types = self.type_validation or self.type_validation is None
        new_value = _wrap_signature_and_type_validation(
            new_value, target, attr_name, validate_types
        )

        # Remember whether the attribute was set directly on the target (and
        # with which value) so the unpatcher can put it back.
        restore = attr_name in target.__dict__
        restore_value = target.__dict__.get(attr_name, None)

        if inspect.isclass(target):
            new_value = staticmethod(new_value)  # type: ignore

        unpatcher = _patch(target, attr_name, new_value, restore, restore_value)
        return original_callable, unpatcher

    def _get_callable_mock(self) -> _CallableMock:
        """Build the _CallableMock that will stand in for the target attribute."""
        mock = _CallableMock(
            target=self._original_target,
            method=self._method,
            caller_frame_info=self.caller_frame_info,
            type_validation=self.type_validation,
        )
        return mock

    def __init__(
        self,
        target: Any,
        method: str,
        caller_frame_info: Traceback,
        callable_mock: Union[Callable[[Type[object]], Any], _CallableMock, None] = None,
        original_callable: Optional[Callable] = None,
        allow_private: bool = False,
        type_validation: Optional[bool] = None,
    ) -> None:
        """
        Prepare the DSL for ``method`` of ``target``.

        A string ``target`` is resolved with ``testslide._importer``. The first
        DSL created for a given (target, method) registers a callable mock and,
        unless a ``callable_mock`` was supplied, patches the target. Later ones
        reuse the registered mock.

        :raises RuntimeError: when the test framework integration was not set up.
        """
        if not _is_setup():
            raise RuntimeError(
                "TestSlide was not correctly setup before usage!\n"
                "This error happens when mock_callable, mock_async_callable or "
                "mock_constructor are attempted to be used without correct "
                "test framework integration, meaning unpatching and "
                "assertions will not work as expected.\n"
                "A common scenario for this is when testslide.TestCase is "
                "subclassed with setUp() overridden but super().setUp() was not "
                "called."
            )
        self._original_target = target
        self._method = method
        self._runner: Optional[_BaseRunner] = None
        self._next_runner_accepted_args: Any = None
        self.allow_private = allow_private
        self.type_validation = type_validation
        self.caller_frame_info = caller_frame_info
        self._allow_coro = False
        self._accept_partial_call = False
        self._target = (
            testslide._importer(target) if isinstance(target, str) else target
        )

        registry_key = (id(self._target), method)

        if registry_key in self.CALLABLE_MOCKS:
            # Already mocked: share the registered mock and its original.
            self._callable_mock = self.CALLABLE_MOCKS[registry_key]
            self._original_callable = self._callable_mock.original_callable  # type: ignore
            return

        # Patching is only needed when no ready-made mock was supplied.
        needs_patch = not callable_mock
        if needs_patch:
            callable_mock = self._get_callable_mock()
        self.CALLABLE_MOCKS[registry_key] = callable_mock
        self._callable_mock = callable_mock

        def del_callable_mock() -> None:
            del self.CALLABLE_MOCKS[registry_key]

        _unpatchers.append(del_callable_mock)

        if needs_patch:
            original_callable, unpatcher = self._patch(callable_mock)
            _unpatchers.append(unpatcher)
        self._original_callable = original_callable
        callable_mock.original_callable = original_callable  # type: ignore

    def _add_runner(self, runner: _BaseRunner) -> None:
        """
        Attach ``runner`` as the single behavior of this DSL chain.

        Arguments given earlier through for_call() / for_partial_call() are
        applied to the runner, which is then placed first in the callable
        mock's runner list.

        :raises ValueError: when this chain already has a behavior.
        """
        if self._runner:
            raise ValueError(
                "Can't define more than one behavior using the same "
                "self.mock_callable() chain. Please call self.mock_callable() again "
                "one time for each new behavior."
            )
        pending = self._next_runner_accepted_args
        if pending:
            pending_args, pending_kwargs = pending
            self._next_runner_accepted_args = None
            runner.add_accepted_args(
                self._accept_partial_call, *pending_args, **pending_kwargs
            )
            self._accept_partial_call = False
        self._runner = runner
        self._callable_mock.runners.insert(0, runner)  # type: ignore

    def _assert_runner(self) -> None:
        """
        Ensure a behavior is defined and has not received any call yet.

        :raises ValueError: when no behavior was defined, or when the behavior
            already received at least one call.
        """
        runner = self._runner
        if not runner:
            raise ValueError(
                "You must first define a behavior. Eg: "
                "self.mock_callable(target, 'func')"
                ".to_return_value(value)"
                ".and_assert_called_exactly(times)"
            )
        received = runner._call_count
        if received > 0:
            plural = "s" if received > 1 else ""
            raise ValueError(
                f"No extra configuration is allowed after {self._NAME} "
                f"receives its first call, it received {received} "
                f"call{plural} already. "
                "You should instead define it all at once, "
                f"eg: self.{self._NAME}(target, 'func')"
                ".to_return_value(value).and_assert_called_once()"
            )

    ##
    ## Arguments
    ##

    def for_call(
        self, *args: Any, **kwargs: Any
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Filter for only calls like this.

        The given arguments must match the received call exactly. When a
        behavior is already defined they are applied to it right away;
        otherwise they are kept for the next behavior defined on this chain.

        :return: this DSL object, to allow chaining.
        """
        if self._runner:
            self._runner.add_accepted_args(False, *args, **kwargs)
        else:
            # Drop a pending partial-match request left by an earlier
            # for_partial_call(), otherwise these arguments would be matched
            # partially instead of exactly.
            self._accept_partial_call = False
            self._next_runner_accepted_args = (args, kwargs)
        return self

    def for_partial_call(
        self, *args: Any, **kwargs: Any
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        if self._runner:
            self._runner.add_accepted_args(True, *args, **kwargs)
        else:
            self._accept_partial_call = True
            self._next_runner_accepted_args = (args, kwargs)
        return self

    ##
    ## Behavior
    ##

    def to_return_value(
        self, value: Any
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Define the behavior of this chain: always return the given value.

        Returns self, so calls can be chained.
        """
        runner = _ReturnValueRunner(
            self._original_target,
            self._method,
            self._original_callable,  # type: ignore
            value,
            self._allow_coro,
        )
        self._add_runner(runner)
        return self

    def to_return_values(
        self, values_list: List[Any]
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Define the behavior of this chain: each call returns the next value
        from the given list, in order. When the list is exhausted, goes to the
        next behavior set.

        Raises ValueError if values_list is not a list.
        Returns self, so calls can be chained.
        """
        if not isinstance(values_list, list):
            raise ValueError(f"{values_list} is not a list")
        runner = _ReturnValuesRunner(
            self._original_target,
            self._method,
            self._original_callable,  # type: ignore
            values_list,
            self._allow_coro,
        )
        self._add_runner(runner)
        return self

    def to_yield_values(
        self, values_list: List[Any]
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Define the behavior of this chain: the callable will return an iterator
        that will yield each value from the given list.

        Raises ValueError if values_list is not a list.
        Returns self, so calls can be chained.
        """
        if not isinstance(values_list, list):
            raise ValueError(f"{values_list} is not a list")
        runner = _YieldValuesRunner(
            self._original_target,
            self._method,
            self._original_callable,  # type: ignore
            values_list,
            self._allow_coro,
        )
        self._add_runner(runner)
        return self

    def to_raise(
        self, ex: Union[Type[BaseException], BaseException]
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Raises given exception class or exception instance.

        An exception instance is used as is. Anything else that is callable
        (typically an exception class) is called exactly once with no
        arguments, and the result is the exception that gets raised.

        Raises ValueError if ex is neither an exception instance nor a callable
        that produces one.
        Returns self, so calls can be chained.
        """
        exception = ex
        if not isinstance(exception, BaseException) and callable(ex):
            # Call only once, so the instance that is validated below is the
            # very same instance handed to the runner.
            exception = ex()
        if not isinstance(exception, BaseException):
            # Also reached for non-callable values (eg: a string), which must
            # be reported as a ValueError rather than blow up when called.
            raise ValueError(
                "{} is not subclass or instance of BaseException".format(ex)
            )
        self._add_runner(
            _RaiseRunner(
                self._original_target, self._method, self._original_callable, exception  # type: ignore
            )
        )
        return self

    def with_implementation(
        self, func: Callable
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Define the behavior of this chain: replace the callable by the given
        function.

        Raises ValueError if func is not callable.
        Returns self, so calls can be chained.
        """
        if not callable(func):
            raise ValueError(f"{func} must be callable.")
        runner = _ImplementationRunner(
            self._original_target,
            self._method,
            self._original_callable,  # type: ignore
            func,
            self._allow_coro,
        )
        self._add_runner(runner)
        return self

    def with_wrapper(
        self, func: Callable
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Replace callable with given wrapper function, that will be called as:

          func(original_func, *args, **kwargs)

        receiving the original function as the first argument as well as any given
        arguments.
        """
        if not callable(func):
            raise ValueError("{} must be callable.".format(func))

        if not self._original_callable:
            raise ValueError("Can not wrap original callable that does not exist.")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return func(self._original_callable, *args, **kwargs)

        self._add_runner(
            _ImplementationRunner(
                self._original_target,
                self._method,
                self._original_callable,
                wrapper,
                self._allow_coro,
            )
        )
        return self

    def to_call_original(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Calls the original callable implementation, instead of mocking it. This is
        useful for example, if you want to by default call the original implementation,
        but for a specific calls, mock the result.
        """
        if not self._original_callable:
            raise ValueError("Can not call original callable that does not exist.")
        self._add_runner(
            _CallOriginalRunner(
                self._original_target, self._method, self._original_callable
            )
        )
        return self

    ##
    ## Call Assertions
    ##

    def and_assert_called_exactly(
        self, count: int
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that there were exactly the given number of calls.

        If assertion is for 0 calls, any received call will raise
        UnexpectedCallReceived and also an AssertionError.
        """
        if count:
            self._assert_runner()
        else:
            if not self._runner:
                self.to_raise(
                    UnexpectedCallReceived(
                        ("{}, {}: Expected not to be called!").format(
                            _format_target(self._target), repr(self._method)
                        )
                    )
                )
        self._runner.add_exact_calls_assertion(count)  # type: ignore
        return self

    def and_assert_not_called(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for and_assert_called_exactly(0)
        """
        return self.and_assert_called_exactly(0)

    def and_assert_called_once(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for and_assert_called_exactly(1)
        """
        return self.and_assert_called_exactly(1)

    def and_assert_called_twice(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for and_assert_called_exactly(2)
        """
        return self.and_assert_called_exactly(2)

    def and_assert_called_at_least(
        self, count: int
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that there at least the given number of calls.
        """
        if count < 1:
            raise ValueError("times must be >= 1")
        self._assert_runner()
        self._runner.add_at_least_calls_assertion(count)  # type: ignore
        return self

    def and_assert_called_at_most(
        self, count: int
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that there at most the given number of calls.
        """
        if count < 1:
            raise ValueError("times must be >= 1")
        self._assert_runner()
        self._runner.add_at_most_calls_assertion(count)  # type: ignore
        return self

    def and_assert_called(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Short for self.and_assert_called_at_least(1).
        """
        return self.and_assert_called_at_least(1)

    def and_assert_called_ordered(
        self,
    ) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
        """
        Assert that multiple calls, potentially to different mock_callable()
        targets, happened in the order defined.
        """
        self._assert_runner()
        self._runner.add_call_order_assertion()  # type: ignore
        return self


class _MockAsyncCallableDSL(_MockCallableDSL):
    """
    DSL object returned by mock_async_callable().

    It extends _MockCallableDSL, replacing the implementation / wrapper /
    call-original behaviors by their async runner counterparts.
    """

    _NAME: str = "mock_async_callable"

    def __init__(
        self,
        target: Union[str, type],
        method: str,
        caller_frame_info: Traceback,
        callable_returns_coroutine: bool,
        allow_private: bool = False,
        type_validation: bool = True,
    ) -> None:
        """
        Store callable_returns_coroutine, delegate the remaining setup to the
        parent class and then flag this DSL as allowing coroutines.
        """
        # Assigned before the parent initializer runs, same as the other
        # attributes are only available after it.
        self._callable_returns_coroutine = callable_returns_coroutine
        parent_options = {
            "allow_private": allow_private,
            "type_validation": type_validation,
        }
        super().__init__(target, method, caller_frame_info, **parent_options)
        self._allow_coro = True

    def _validate_patch(self) -> None:  # type: ignore
        """
        Run the parent validation with the settings of an async callable.
        """
        return super()._validate_patch(
            name=self._NAME,
            other_name="mock_callable",
            coroutine_function=True,
            callable_returns_coroutine=self._callable_returns_coroutine,
        )

    def _get_callable_mock(self) -> _CallableMock:
        """
        Build the _CallableMock for the target method, flagged as async.
        """
        callable_mock = _CallableMock(
            self._original_target,
            self._method,
            self.caller_frame_info,
            is_async=True,
            callable_returns_coroutine=self._callable_returns_coroutine,
            type_validation=self.type_validation,
        )
        return callable_mock

    def with_implementation(self, func: Callable) -> "_MockAsyncCallableDSL":
        """
        Define the behavior of this chain: replace the callable by the given
        async function.

        Raises ValueError if func is not callable.
        Returns self, so calls can be chained.
        """
        if not callable(func):
            raise ValueError(f"{func} must be callable.")
        runner = _AsyncImplementationRunner(
            self._original_target, self._method, self._original_callable, func  # type: ignore
        )
        self._add_runner(runner)
        return self

    def with_wrapper(self, func: Callable) -> "_MockAsyncCallableDSL":
        """
        Define the behavior of this chain: replace the callable with the given
        async wrapper function, which will be called as:

          await func(original_async_func, *args, **kwargs)

        so it receives the original function as the first argument, followed by
        whatever arguments the call was made with.

        Raises ValueError if func is not callable, or if there is no original
        callable to wrap. When called, the replacement raises NotACoroutine if
        func did not return a coroutine.
        Returns self, so calls can be chained.
        """
        if not callable(func):
            raise ValueError(f"{func} must be callable.")
        if not self._original_callable:
            raise ValueError("Can not wrap original callable that does not exist.")

        async def call_wrapped(*args: Any, **kwargs: Any) -> Any:
            # The original callable is looked up on self at call time.
            result = func(self._original_callable, *args, **kwargs)
            if _is_coroutine(result):
                return await result
            raise NotACoroutine(
                "Function did not return a coroutine.\n"
                f"{func} must return a coroutine."
            )

        wrapper = functools.wraps(func)(call_wrapped)
        runner = _AsyncImplementationRunner(
            self._original_target, self._method, self._original_callable, wrapper
        )
        self._add_runner(runner)
        return self

    def to_call_original(self) -> "_MockAsyncCallableDSL":
        """
        Calls the original callable implementation, instead of mocking it. This
        is useful, for example, when you want to call the original
        implementation by default, but mock the result for some specific calls.

        Raises ValueError if there is no original callable.
        Returns self, so calls can be chained.
        """
        original = self._original_callable
        if not original:
            raise ValueError("Can not call original callable that does not exist.")
        runner = _AsyncCallOriginalRunner(
            self._original_target, self._method, original
        )
        self._add_runner(runner)
        return self
