optimum/neuron/utils/patching.py (155 lines of code) (raw):
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to patching."""
import functools
import importlib
import inspect
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union
if TYPE_CHECKING:
from transformers import PreTrainedModel
class BasePatcher(ABC):
    """
    Base abstract class providing the core features for efficient context manager based patching.

    Subclasses implement `process_patching_specs`, which turns user-facing specs into tuples of the form
    `(owner, attribute_name, original_value, patch, should_delete_attribute_at_restore)`. This class applies those
    tuples in `patch` and reverts them in `restore`; it can also be used as a context manager.
    """

    def __init__(
        self, patching_specs: Optional[List[Tuple[Any, ...]]] = None, ignore_missing_attributes: bool = False
    ):
        self.patching_specs = self.process_patching_specs(
            patching_specs, ignore_missing_attributes=ignore_missing_attributes
        )
        # Tracks whether the patches are currently applied, so that `patch` / `restore` are no-ops when repeated.
        self.already_patched = False

    @abstractmethod
    def process_patching_specs(
        self, patching_specs: Optional[List[Tuple[Any, Any]]] = None, ignore_missing_attributes: bool = False
    ) -> List[Tuple[Any, str, Any, Any, bool]]:
        """
        Converts user-facing patching specs into the
        `(owner, attribute_name, original_value, patch, should_delete_attribute_at_restore)` tuples consumed by
        `patch` and `restore`.
        """
        pass

    @staticmethod
    def _restore_spec(spec: Tuple[Any, str, Any, Any, bool]):
        """Reverts a single processed spec: deletes the attribute if it did not exist, else puts the original back."""
        module, attribute_name, orig, _, should_delete_attribute_at_restore = spec
        if should_delete_attribute_at_restore:
            delattr(module, attribute_name)
        else:
            setattr(module, attribute_name, orig)

    def patch(self):
        """
        Applies every patch. Does nothing if the patches are already applied.

        If applying one of the patches fails, the patches applied so far are reverted before the error propagates,
        so a failed `patch()` (or failed `__enter__`, after which `__exit__` is never called) leaves no partial state.
        """
        if self.already_patched:
            return
        applied = []
        try:
            for spec in self.patching_specs:
                module, attribute_name, _, patch, _ = spec
                setattr(module, attribute_name, patch)
                applied.append(spec)
        except BaseException:
            # Roll back in reverse order of application, then re-raise the original error.
            for spec in reversed(applied):
                self._restore_spec(spec)
            raise
        self.already_patched = True

    def restore(self):
        """Reverts every patch. Does nothing if the patches are not currently applied."""
        if not self.already_patched:
            return
        for spec in self.patching_specs:
            self._restore_spec(spec)
        self.already_patched = False

    def __enter__(self):
        self.patch()
        # Returning the patcher makes `with Patcher(...) as patcher:` usable.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Implicitly returns None: exceptions raised inside the `with` block are never suppressed.
        self.restore()
class DynamicPatch:
    """
    Wraps a function that builds a patch from the attribute it is going to replace.

    Use it instead of a plain replacement value when the patch depends on the original attribute (for instance a
    wrapper around the original function): calling the instance with the original attribute returns the patch.
    """

    def __init__(self, patch_function: Callable[[Any], Any]):
        # Maps the original attribute to its replacement.
        self.patch_function = patch_function

    def __call__(self, attribute: Any) -> Any:
        build_patch = self.patch_function
        return build_patch(attribute)
class Patcher(BasePatcher):
    """
    Context manager that patches attributes of a module under its scope and restores everything after exit.

    Each patching spec is a `(qualified_name, patch)` pair:
        - `qualified_name` is the dotted path of the attribute to patch, e.g. `"package.module.function"` or
          `"package.module.SomeClass.method"`.
        - `patch` is either the replacement value, or a `DynamicPatch` that builds the replacement from the
          original attribute.
    """

    @staticmethod
    def _resolve_attribute_owner(owner_qualified_name: str) -> Any:
        """
        Returns the object designated by the dotted path `owner_qualified_name`.

        The full path is first imported as a module. If that fails, the longest importable module prefix is
        imported and the remaining components are resolved with `getattr` (e.g. `"pkg.mod.Outer.Inner"` imports
        `pkg.mod` and returns `pkg.mod.Outer.Inner`).

        Raises:
            `ModuleNotFoundError`: The error raised when importing the full path, if the path cannot be resolved.
        """
        try:
            return importlib.import_module(owner_qualified_name)
        except ModuleNotFoundError as import_error:
            first_error = import_error

        module_name = owner_qualified_name
        remaining_names = []
        while "." in module_name:
            module_name, last_name = module_name.rsplit(".", maxsplit=1)
            remaining_names.insert(0, last_name)
            try:
                owner = importlib.import_module(module_name)
            except ModuleNotFoundError:
                # Keep walking up until an importable module prefix is found.
                continue
            try:
                for name in remaining_names:
                    owner = getattr(owner, name)
            except AttributeError:
                # Surface the original import error, which best describes what could not be found.
                raise first_error
            return owner
        raise first_error

    def process_patching_specs(
        self, patching_specs: Optional[List[Tuple[str, Any]]] = None, ignore_missing_attributes: bool = False
    ):
        """
        Resolves every `(qualified_name, patch)` spec into a
        `(owner, attribute_name, original_value, patch, should_delete_attribute_at_restore)` tuple.

        Raises:
            `ValueError`: If a qualified name contains no dot, or if a `DynamicPatch` targets a missing attribute.
            `ModuleNotFoundError`: If the object owning an attribute cannot be imported / resolved.
            `AttributeError`: If an attribute is missing and `ignore_missing_attributes=False`.
        """
        processed_patching_specs = []
        for orig, patch in patching_specs or []:
            if "." not in orig:
                raise ValueError(f"Expected a qualified name of the form `module.attribute` but got `{orig}`.")
            module_qualified_name, attribute_name = orig.rsplit(".", maxsplit=1)
            module = self._resolve_attribute_owner(module_qualified_name)

            module_has_attr = hasattr(module, attribute_name)
            if module_has_attr:
                attribute = getattr(module, attribute_name)
            elif isinstance(patch, DynamicPatch):
                # A dynamic patch needs the original attribute to build the replacement.
                raise ValueError("Cannot ignore missing attribute with a DynamicPatch.")
            elif ignore_missing_attributes:
                attribute = None
            else:
                raise AttributeError(
                    f"Attribute {attribute_name} does not exist in {module}, set `ignore_missing_attributes=True` "
                    "to allow not failing when an attribute does not exist."
                )

            if isinstance(patch, DynamicPatch):
                patch = patch(attribute)

            # An attribute that did not exist before patching is deleted (not reset) at restore time.
            processed_patching_specs.append((module, attribute_name, attribute, patch, not module_has_attr))
        return processed_patching_specs
class ModelPatcher(BasePatcher):
    """
    Context manager that patches attributes of a model under its scope and restores everything after exit.

    Each patching spec is a `(model, attribute_qualified_name, patch)` triple where `attribute_qualified_name` is a
    dotted path relative to `model` (e.g. `"forward"` or `"encoder.forward"`), and `patch` is either the
    replacement value or a `DynamicPatch` that builds the replacement from the original attribute.
    """

    def process_patching_specs(
        self,
        patching_specs: Optional[List[Tuple["PreTrainedModel", str, Any]]] = None,
        ignore_missing_attributes: bool = False,
    ):
        """
        Resolves every spec into a
        `(owner, attribute_name, original_value, patch, should_delete_attribute_at_restore)` tuple.

        Raises:
            `AttributeError`: If an intermediate object of the dotted path is missing, or if the final attribute is
                missing and `ignore_missing_attributes=False`.
            `ValueError`: If a `DynamicPatch` targets a missing attribute.
        """
        processed_patching_specs = []
        for model, attribute_qualified_name, patch in patching_specs or []:
            *owner_path, attribute_name = attribute_qualified_name.split(".")
            module = model
            for name in owner_path:
                module = getattr(module, name)

            module_has_attr = hasattr(module, attribute_name)
            if module_has_attr:
                attribute = getattr(module, attribute_name)
            elif isinstance(patch, DynamicPatch):
                # A dynamic patch needs the original attribute to build the replacement.
                raise ValueError("Cannot ignore missing attribute with a DynamicPatch.")
            elif ignore_missing_attributes:
                attribute = None
            else:
                raise AttributeError(
                    f"Attribute {attribute_name} does not exist in {module}, set `ignore_missing_attributes=True` "
                    "to allow not failing when an attribute does not exist."
                )

            if isinstance(patch, DynamicPatch):
                patch = patch(attribute)

            if inspect.ismethod(attribute):
                # Bind the patch to the object that owns the patched method. For a dotted path this is the
                # submodule, not the top-level `model`, so that `self` inside the patch matches the original.
                patch = patch.__get__(module)

            # An attribute that did not exist before patching is deleted (not reset) at restore time.
            processed_patching_specs.append((module, attribute_name, attribute, patch, not module_has_attr))
        return processed_patching_specs
def patch_within_function(
    patching_specs: Union[List[Tuple[str, Any]], Tuple[str, Any]], ignore_missing_attributes: bool = False
):
    """
    Decorator that patches attributes of a module during the lifetime of the decorated function.

    The attributes to patch are resolved once, when the decorator is created. If the decorated callable is a bound
    method, the returned wrapper is bound to the same object. Re-entrant calls (e.g. recursion) keep the patches in
    place until the outermost call returns.

    Args:
        patching_specs (`Union[List[Tuple[str, Any]], Tuple[str, Any]]`):
            The specifications of what to patch, either a single `(qualified_name, patch)` pair or a list of them.
        ignore_missing_attributes (`bool`, defaults to `False`):
            Whether or not the patch should fail if the attribute to patch does not exist.

    Returns:
        `Callable`: A decorator producing a patched version of the function.
    """
    if isinstance(patching_specs, tuple) and len(patching_specs) == 2:
        patching_specs = [patching_specs]
    patcher = Patcher(patching_specs, ignore_missing_attributes=ignore_missing_attributes)

    def decorator(func):
        # Builtins also expose `__self__` (e.g. the `builtins` module for `len`) but have no `__func__`, so both
        # attributes are required to treat `func` as a bound method.
        is_bound = hasattr(func, "__self__") and hasattr(func, "__func__")

        @functools.wraps(func.__func__ if is_bound else func)
        def wrapper(*args, **kwargs):
            if is_bound:
                # `wrapper` is re-bound below, so the instance is passed again here; `func` is already bound.
                args = args[1:]
            if patcher.already_patched:
                # Re-entrant call: the outermost call owns the patches and restores them when it exits. Entering
                # the patcher again would restore the original attributes too early.
                return func(*args, **kwargs)
            with patcher:
                return func(*args, **kwargs)

        if is_bound:
            wrapper = wrapper.__get__(func.__self__)
        return wrapper

    return decorator
def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None):
    """
    Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.

    Every call performs the patching again. The function is intentionally not memoized: a cache would turn a repeated
    call (e.g. re-applying a patch after it was reverted) into a silent no-op, and would reject unhashable patches.

    Args:
        attribute_name (`str`):
            The name of attribute to patch.
        patch (`Any`):
            The patch for the attribute.
        module_name_prefix (`Optional[str]`, defaults to `None`):
            If set, only module names starting with this prefix will be considered for patching.
    """
    # Iterate over a snapshot: attribute lookups may import modules and mutate `sys.modules` during the loop.
    for name, module in dict(sys.modules).items():
        if module_name_prefix is not None and not name.startswith(module_name_prefix):
            continue
        if hasattr(module, attribute_name):
            setattr(module, attribute_name, patch)
def replace_class_in_inheritance_hierarchy(obj: Any, orig_cls: Type, replacement_cls: Type):
    """
    Inspects the inheritance hierarchy of `obj` and replaces `orig_cls` by `replacement_cls` if found.

    The hierarchy is traversed breadth-first starting from `obj.__class__`. In each visited class, bases equal to
    `orig_cls` are replaced by `replacement_cls`. The traversal stops after the first class whose bases contain
    `orig_cls` (now replaced) or `replacement_cls` (already replaced). `obj.__class__` itself is never replaced, only
    its (transitive) bases are.

    Args:
        obj (`Any`):
            The object whose class hierarchy is modified in place.
        orig_cls (`Type`):
            The base class to replace.
        replacement_cls (`Type`):
            The class to use instead of `orig_cls`.
    """
    to_visit = [obj.__class__]
    should_stop = False
    while to_visit and not should_stop:
        cls = to_visit.pop(0)
        if cls is object:
            continue
        bases = cls.__bases__
        new_bases = []
        for base in bases:
            to_visit.append(base)
            if base == orig_cls:
                new_bases.append(replacement_cls)
                should_stop = True
            else:
                if base == replacement_cls:
                    should_stop = True
                new_bases.append(base)
        # Only reassign when a base actually changed. `__bases__` cannot be set on builtin / extension types (e.g.
        # `dict`), even to its current value, which would otherwise abort the traversal with a `TypeError`.
        if any(new is not old for new, old in zip(new_bases, bases)):
            cls.__bases__ = tuple(new_bases)