optimum/neuron/utils/deprecate_utils.py (49 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions to handle deprecation."""

import functools
import inspect
import warnings
from typing import Callable, Dict

from packaging import version

from ..version import __version__
from .version_utils import (
    get_neuroncc_version,
    get_neuronx_distributed_version,
    get_neuronxcc_version,
    get_torch_version,
    get_torch_xla_version,
)


def get_transformers_version() -> str:
    """Return the version string of the installed `transformers` package.

    The import happens inside the function, so a missing `transformers`
    installation only raises when this function is actually called.
    """
    import transformers

    return transformers.__version__


# Maps a package name to a zero-argument callable returning its version string.
PACKAGE_NAME_TO_GET_VERSION_FUNCTION: Dict[str, Callable[[], str]] = {
    "transformers": get_transformers_version,
    "optimum-neuron": lambda: __version__,
    "neuroncc": get_neuroncc_version,
    "neuronxcc": get_neuronxcc_version,
    "torch": get_torch_version,
    "torch_xla": get_torch_xla_version,
    "neuronx_distributed": get_neuronx_distributed_version,
}


def deprecate(deprecate_version: str, package_name: str = "optimum-neuron", reason: str = ""):
    """Build a decorator that emits a `DeprecationWarning` when the decorated callable is used.

    The warning is only emitted when the version of `package_name` (resolved once, when
    `deprecate` is called) is greater than or equal to `deprecate_version`.

    Args:
        deprecate_version: Version of `package_name` from which the callable is deprecated.
        package_name: Key of `PACKAGE_NAME_TO_GET_VERSION_FUNCTION` naming the package whose
            version is compared against `deprecate_version`.
        reason: Optional explanation appended to the warning message.

    Returns:
        A decorator. Regular functions are wrapped by a regular function; generator functions
        are wrapped by a generator function, which warns when iteration starts.

    Raises:
        ValueError: If `package_name` is not a key of `PACKAGE_NAME_TO_GET_VERSION_FUNCTION`.
    """
    if package_name not in PACKAGE_NAME_TO_GET_VERSION_FUNCTION:
        raise ValueError(f"Do not know how to retrieve the version of the package called {package_name}.")

    threshold = version.parse(deprecate_version)
    try:
        raw_package_version = PACKAGE_NAME_TO_GET_VERSION_FUNCTION[package_name]()
    except ModuleNotFoundError:
        # We do not want to fail if the package is not available, otherwise it would make
        # developing locally impossible.
        raw_package_version = "0.0.0"
    is_deprecated = version.parse(raw_package_version) >= threshold

    def deprecator(func):
        message = f"{func.__name__} is deprecated."
        if reason:
            message = f"{message} Reason: {reason}"

        # The wrapper kind must be picked here, at decoration time: a function body containing
        # `yield` is always compiled as a generator function, so a single wrapper that mixes
        # `yield from` and `return` would turn every decorated regular function into a
        # generator that never runs unless it is iterated.
        if inspect.isgeneratorfunction(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if is_deprecated:
                    # stacklevel=2 attributes the warning to the code consuming the generator.
                    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
                # Forward the generator's return value as well as the yielded items.
                return (yield from func(*args, **kwargs))

        else:

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if is_deprecated:
                    # stacklevel=2 attributes the warning to the caller of the deprecated function.
                    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
                return func(*args, **kwargs)

        return wrapper

    return deprecator