testslide/patch.py (78 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Callable, Dict, Optional, Union


def _unwrap_proxy(value: Any) -> Any:
    """Return the class attribute hidden behind any ``_DescriptorProxy`` layers.

    Each proxy keeps the attribute it replaced in ``original_class_attr``; that
    may itself be a proxy when the same attribute was patched more than once.
    The result is ``None`` when the innermost proxy replaced nothing (the
    attribute was only inherited when it was patched).
    """
    while isinstance(value, _DescriptorProxy):
        value = value.original_class_attr
    return value


def _bind(value: Any, instance: object, owner: object) -> Any:
    """Resolve a class-level ``value`` for ``instance`` like normal attribute lookup.

    Descriptors (functions, properties, proxies, ...) are bound through their
    type's ``__get__``; plain class data such as ``x = 0`` is returned as is.
    """
    getter = getattr(type(value), "__get__", None)
    if getter is None:
        return value
    return getter(value, instance, owner)


class _DescriptorProxy:
    """Data descriptor installed on a class to override an attribute per instance.

    Values assigned through it are kept in ``instance_attr_map`` keyed by
    ``id(instance)``; every other instance resolves the attribute as it did
    before the proxy was installed. ``id()`` values can be reused once an
    object is garbage-collected, so the proxy should be removed (via the
    unpatcher) while the patched instance is still alive.
    """

    def __init__(
        self,
        original_class_attr: Optional[Any],
        attr_name: str,
    ) -> None:
        # What the owning class's ``__dict__`` held under ``attr_name`` before
        # this proxy replaced it, or None if the attribute was only inherited.
        self.original_class_attr = original_class_attr
        self.attr_name = attr_name
        # id(instance) -> value that only that instance should see.
        self.instance_attr_map: Dict[int, Any] = {}

    def __set__(self, instance: object, value: Any) -> None:
        """Override the attribute for ``instance`` only."""
        self.instance_attr_map[id(instance)] = value

    def __get__(
        self, instance: object, owner: object
    ) -> Union["_DescriptorProxy", Any]:
        """Return ``instance``'s override, else the attribute it would otherwise see.

        Raises:
            AttributeError: if nothing in ``owner``'s MRO provides the attribute.
        """
        if instance is None:
            # Accessed on the class itself: expose the proxy object.
            return self
        key = id(instance)
        if key in self.instance_attr_map:
            return self.instance_attr_map[key]
        if self.original_class_attr is not None:
            # Delegate to what this proxy replaced; a nested proxy will in turn
            # consult its own per-instance overrides.
            return _bind(self.original_class_attr, instance, owner)
        # The attribute was inherited: continue the lookup in the base classes,
        # looking through proxies that other patches installed on them (they
        # must not be skipped, or the real definition they wrap is lost).
        for parent in owner.__mro__[1:]:  # type: ignore
            candidate = _unwrap_proxy(parent.__dict__.get(self.attr_name))
            if candidate is not None:
                return _bind(candidate, instance, owner)
        raise AttributeError(
            f"{type(instance).__name__!r} object has no attribute {self.attr_name!r}"
        )

    def __delete__(self, instance: object) -> None:
        """Drop ``instance``'s override, if any; other instances are unaffected."""
        # The map is keyed by id(instance), not by the instance itself.
        self.instance_attr_map.pop(id(instance), None)


def _is_instance_method(target: Any, method: str) -> bool:
    """Return True if a class in ``target``'s MRO defines ``method`` as a function.

    ``target`` may be a class or an instance (its type is inspected); modules
    always yield False. Proxies left by earlier patches are looked through, so
    a method that is already patched (even several times) is still recognised.
    """
    if inspect.ismodule(target):
        return False
    klass = target if inspect.isclass(target) else type(target)
    return any(
        inspect.isfunction(_unwrap_proxy(vars(cls).get(method)))
        for cls in klass.__mro__
    )


def _mock_instance_attribute(instance: Any, attr: str, value: Any) -> Callable:
    """
    Patch ``attr`` on ``instance`` only, setting it to ``value``.

    This works for any instance attribute, even when the attribute is defined
    on the class via the descriptor protocol (``__get__``, e.g. ``@property``).
    It allows mocking the attribute on just the desired instance, unlike
    ``unittest.mock.patch.object`` + ``PropertyMock``, which patch at class
    level and therefore affect every instance.

    A ``_DescriptorProxy`` wrapping the class's current attribute is installed
    on ``type(instance)``. Each returned unpatcher writes back exactly the
    class attribute it replaced, so several patches of the same attribute
    should be undone in reverse order.

    Returns:
        A zero-argument callable that restores the class attribute (or removes
        the proxy if the class did not define the attribute itself).
    """
    klass = type(instance)
    class_dict = vars(klass)
    # Track ownership separately from the value so that falsy or None class
    # attributes are restored rather than deleted on unpatch.
    had_class_attr = attr in class_dict
    class_restore_value = class_dict.get(attr)
    setattr(klass, attr, _DescriptorProxy(class_restore_value, attr))
    setattr(instance, attr, value)

    def unpatch_class() -> None:
        if had_class_attr:
            setattr(klass, attr, class_restore_value)
        else:
            delattr(klass, attr)

    return unpatch_class


def _patch(
    target: Any, attribute: str, new_value: Any, restore: Any, restore_value: Any = None
) -> Callable:
    """Replace ``attribute`` on ``target`` with ``new_value``; return an undo callable.

    One of three strategies is used:

    * a plain function defined on the target's class (or a base class) is
      overridden via ``_mock_instance_attribute``. NOTE(review): for a class
      ``target`` that installs the proxy on the metaclass, and setattr on the
      builtin ``type`` raises TypeError. Confirm callers never pass classes here.
    * a property found on ``type(target)`` is replaced on that class by a
      read-only property returning ``new_value``. This affects every instance
      of the class until undone.
    * anything else is set directly on ``target``.

    Args:
        target: object (or module) to patch.
        attribute: name of the attribute to replace.
        new_value: replacement value.
        restore: not read by this function; kept for signature compatibility.
            NOTE(review): possibly meant to say whether the attribute existed
            before patching, which would allow restoring an original ``None``.
            Confirm against callers.
        restore_value: for the direct-set strategy, the value put back on
            ``target`` when undoing; ``None`` means the attribute is deleted.

    Returns:
        A zero-argument callable that undoes the patch.
    """
    if _is_instance_method(target, attribute):
        return _mock_instance_attribute(target, attribute, new_value)

    klass = type(target)
    class_attr = getattr(klass, attribute, None)
    if isinstance(class_attr, property):
        original_property = class_attr
        # Whether the property lives on ``klass`` itself or is inherited.
        defined_on_klass = attribute in vars(klass)
        setattr(klass, attribute, property(fget=lambda _: new_value))

        def unpatch_property() -> None:
            # Always undo at class level: the property was swapped on the class,
            # and deleting through the instance would hit the property instead.
            if defined_on_klass:
                setattr(klass, attribute, original_property)
            else:
                delattr(klass, attribute)

        return unpatch_property

    setattr(target, attribute, new_value)

    def unpatcher() -> None:
        # ``is not None`` so falsy originals (0, False, "") are restored too.
        if restore_value is not None:
            setattr(target, attribute, restore_value)
        else:
            delattr(target, attribute)

    return unpatcher