testslide/mock_constructor.py (303 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import gc
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import testslide
from testslide.mock_callable import _CallableMock, _MockCallableDSL
from .lib import (
_bail_if_private,
_validate_callable_arg_types,
_validate_callable_signature,
)
_DO_NOT_COPY_CLASS_ATTRIBUTES = (
"__dict__",
"__doc__",
"__module__",
"__new__",
"__slots__",
)
_unpatchers: List[Callable] = []
_mocked_target_classes: Dict[Union[int, Tuple[int, str]], Tuple[type, object]] = {}
_restore_dict: Dict[Union[int, Tuple[int, str]], Dict[str, Any]] = {}
_init_args_from_original_callable: Optional[Tuple[Any, ...]] = None
_init_kwargs_from_original_callable: Optional[Dict[str, Any]] = None
_mocked_class_by_original_class_id: Dict[Union[Tuple[int, str], int], type] = {}
_target_class_id_by_original_class_id: Dict[int, Union[Tuple[int, str], int]] = {}
def _get_class_or_mock(original_class: Any) -> Any:
"""
If given class was not a target for mock_constructor, return it.
Otherwise, return the mocked subclass.
"""
return _mocked_class_by_original_class_id.get(id(original_class), original_class)
def _is_mocked_class(klass: Type[object]) -> bool:
return id(klass) in [id(k) for k in _mocked_class_by_original_class_id.values()]
def unpatch_all_constructor_mocks() -> None:
"""
This method must be called after every test unconditionally to remove all
active patches.
"""
try:
for unpatcher in _unpatchers:
unpatcher()
finally:
del _unpatchers[:]
class _MockConstructorDSL(_MockCallableDSL):
"""
Specialized version of _MockCallableDSL to call __new__ with correct args
"""
_NAME: str = "mock_constructor"
def __init__(
self,
target: Union[type, str, object],
method: str,
cls: object,
callable_mock: Union[
Optional[Callable[[Type[object]], Any]], Optional[_CallableMock]
] = None,
original_callable: Optional[Callable] = None,
) -> None:
self.cls = cls
caller_frame = inspect.currentframe().f_back # type: ignore
# loading the context ends up reading files from disk and that might block
# the event loop, so we don't do it.
caller_frame_info = inspect.getframeinfo(caller_frame, context=0) # type: ignore
super(_MockConstructorDSL, self).__init__( # type: ignore
target,
method,
caller_frame_info,
callable_mock=callable_mock,
original_callable=original_callable,
)
def for_call(self, *args: Any, **kwargs: Any) -> "_MockConstructorDSL":
return super(_MockConstructorDSL, self).for_call( # type: ignore
*((self.cls,) + args), **kwargs
)
def with_wrapper(self, func: Callable) -> "_MockConstructorDSL":
def new_func(
original_callable: Callable, cls: object, *args: Any, **kwargs: Any
) -> Any:
assert cls == self.cls
def new_original_callable(*args: Any, **kwargs: Any) -> Any:
return original_callable(cls, *args, **kwargs)
return func(new_original_callable, *args, **kwargs)
return super(_MockConstructorDSL, self).with_wrapper(new_func) # type: ignore
def with_implementation(self, func: Callable) -> "_MockConstructorDSL":
def new_func(cls: object, *args: Any, **kwargs: Any) -> Any:
assert cls == self.cls
return func(*args, **kwargs)
return super(_MockConstructorDSL, self).with_implementation(new_func) # type: ignore
def _get_original_init(original_class: type, instance: object, owner: type) -> Any:
target_class_id = _target_class_id_by_original_class_id[id(original_class)]
# If __init__ available at the class __dict__...
if "__init__" in _restore_dict[target_class_id]:
# Use it,
return _restore_dict[target_class_id]["__init__"].__get__(instance, owner)
else:
# otherwise, pull from a parent class.
return original_class.__init__.__get__(instance, owner) # type: ignore
class AttrAccessValidation:
EXCEPTION_MESSAGE = (
"Attribute {} after the class has been used with mock_constructor() "
"is not supported! After using mock_constructor() you must get a "
"pointer to the new mocked class (eg: {}.{})."
)
def __init__(self, name: str, original_class: type, mocked_class: type) -> None:
self.name = name
self.original_class = original_class
self.mocked_class = mocked_class
def __get__(
self, instance: Optional[type], owner: Type[type]
) -> Union[Callable, str]:
mro = owner.mro() # type: ignore
# If owner is a subclass, allow it
if mro.index(owner) < mro.index(self.original_class):
parent_class = mro[mro.index(self.original_class) + 1]
# and return the parent's value
attr = getattr(parent_class, self.name)
if hasattr(attr, "__get__"):
return attr.__get__(instance, parent_class)
else:
return attr
# For class level attributes & methods, we can make it work...
elif instance is None and owner is self.original_class:
# ...by returning the original value from the mocked class
attr = getattr(self.mocked_class, self.name)
if hasattr(attr, "__get__"):
return attr.__get__(instance, self.mocked_class)
else:
return attr
# Disallow for others
else:
raise BaseException(
self.EXCEPTION_MESSAGE.format(
"getting",
self.original_class.__module__,
self.original_class.__name__,
)
)
def __set__(self, instance: object, value: Any) -> None:
raise BaseException(
self.EXCEPTION_MESSAGE.format(
"setting", self.original_class.__module__, self.original_class.__name__
)
)
def __delete__(self, instance: object) -> None:
raise BaseException(
self.EXCEPTION_MESSAGE.format(
"deleting", self.original_class.__module__, self.original_class.__name__
)
)
def _wrap_type_validation(
template: object, callable_mock: _CallableMock, callable_templates: List[Callable]
) -> Callable:
def callable_mock_with_type_validation(*args: Any, **kwargs: Any) -> Any:
for callable_template in callable_templates:
if _validate_callable_signature(
False,
callable_template,
template,
callable_template.__name__,
args,
kwargs,
):
_validate_callable_arg_types(False, callable_template, args, kwargs)
return callable_mock(*args, **kwargs)
return callable_mock_with_type_validation
def _get_mocked_class(
original_class: type,
target_class_id: Union[Tuple[int, str], int],
callable_mock: _CallableMock,
type_validation: bool,
) -> type:
if target_class_id in _target_class_id_by_original_class_id:
raise RuntimeError("Can not mock the same class at two different modules!")
else:
_target_class_id_by_original_class_id[id(original_class)] = target_class_id
original_class_new = original_class.__new__
original_class_init = original_class.__init__ # type: ignore
# Extract class attributes from the target class...
_restore_dict[target_class_id] = {}
class_dict_to_copy = {
name: value
for name, value in original_class.__dict__.items()
if name not in _DO_NOT_COPY_CLASS_ATTRIBUTES
}
for name, value in class_dict_to_copy.items():
try:
delattr(original_class, name)
# Safety net against missing items at _DO_NOT_COPY_CLASS_ATTRIBUTES
except (AttributeError, TypeError):
continue
_restore_dict[target_class_id][name] = value
# ...and reuse them...
mocked_class_dict = {
"__new__": _wrap_type_validation(
original_class,
callable_mock,
[
original_class_new,
original_class_init,
],
)
if type_validation
else callable_mock
}
mocked_class_dict.update(
{
name: value
for name, value in _restore_dict[target_class_id].items()
if name not in ("__new__", "__init__")
}
)
# ...to create the mocked subclass...
mocked_class = type(
str(original_class.__name__), (original_class,), mocked_class_dict
)
# ...and deal with forbidden access to the original class
for name in _restore_dict[target_class_id].keys():
setattr(
original_class,
name,
AttrAccessValidation(name, original_class, mocked_class),
)
# Because __init__ is called after __new__ unconditionally with the same
# arguments, we need to mock it fir this first call, to call the real
# __init__ with the correct arguments.
def init_with_correct_args(self: object, *args: Any, **kwargs: Any) -> None:
global _init_args_from_original_callable, _init_kwargs_from_original_callable
if None not in [
_init_args_from_original_callable,
_init_kwargs_from_original_callable,
]:
args = _init_args_from_original_callable # type: ignore
kwargs = _init_kwargs_from_original_callable # type: ignore
original_init = _get_original_init(
original_class, instance=self, owner=mocked_class
)
try:
original_init(*args, **kwargs)
finally:
_init_args_from_original_callable = None
_init_kwargs_from_original_callable = None
mocked_class.__init__ = init_with_correct_args # type: ignore
return mocked_class
def _patch_and_return_mocked_class(
target: object,
class_name: str,
target_class_id: Union[Tuple[int, str], int],
original_class: type,
callable_mock: _CallableMock,
type_validation: bool,
) -> type:
mocked_class = _get_mocked_class(
original_class, target_class_id, callable_mock, type_validation
)
def unpatcher() -> None:
for name, value in _restore_dict[target_class_id].items():
setattr(original_class, name, value)
del _restore_dict[target_class_id]
setattr(target, class_name, original_class)
del _mocked_target_classes[target_class_id]
del _mocked_class_by_original_class_id[id(original_class)]
del _target_class_id_by_original_class_id[id(original_class)]
_unpatchers.append(unpatcher)
setattr(target, class_name, mocked_class)
_mocked_target_classes[target_class_id] = (original_class, mocked_class)
_mocked_class_by_original_class_id[id(original_class)] = mocked_class
return mocked_class
def mock_constructor(
target: str,
class_name: str,
allow_private: bool = False,
type_validation: bool = True,
) -> _MockConstructorDSL:
if not isinstance(class_name, str):
raise ValueError("Second argument must be a string with the name of the class.")
_bail_if_private(class_name, allow_private)
if isinstance(target, str):
target = testslide._importer(target)
target_class_id = (id(target), class_name)
if target_class_id in _mocked_target_classes:
original_class, mocked_class = _mocked_target_classes[target_class_id]
if not getattr(target, class_name) is mocked_class:
raise AssertionError(
"The class {} at {} was changed after mock_constructor() mocked "
"it!".format(class_name, target)
)
callable_mock = mocked_class.__new__
else:
original_class = getattr(target, class_name)
if "__new__" in original_class.__dict__:
raise NotImplementedError(
"Usage with classes that define __new__() is currently not supported."
)
gc.collect()
instances = [
obj
for obj in gc.get_referrers(original_class)
if type(obj) is original_class
]
if instances:
raise RuntimeError(
"mock_constructor() can not be used after instances of {} were created: {}".format(
class_name, instances
)
)
if not inspect.isclass(original_class):
raise ValueError("Target must be a class.")
elif not issubclass(original_class, object):
raise ValueError("Old style classes are not supported.")
caller_frame = inspect.currentframe().f_back # type: ignore
# loading the context ends up reading files from disk and that might block
# the event loop, so we don't do it.
caller_frame_info = inspect.getframeinfo(caller_frame, context=0) # type: ignore
callable_mock = _CallableMock(original_class, "__new__", caller_frame_info)
mocked_class = _patch_and_return_mocked_class(
target,
class_name,
target_class_id,
original_class,
callable_mock,
type_validation,
)
def original_callable(cls: type, *args: Any, **kwargs: Any) -> Any:
global _init_args_from_original_callable, _init_kwargs_from_original_callable
assert cls is mocked_class
# Python unconditionally calls __init__ with the same arguments as
# __new__ once it is invoked. We save the correct arguments here,
# so that __init__ can use them when invoked for the first time.
_init_args_from_original_callable = args
_init_kwargs_from_original_callable = kwargs
return object.__new__(cls)
return _MockConstructorDSL(
target=mocked_class,
method="__new__",
cls=mocked_class,
callable_mock=callable_mock,
original_callable=original_callable,
)