optimum/neuron/utils/deprecate_utils.py (49 lines of code) (raw):
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions to handle deprecation."""
import functools
import inspect
import warnings
from typing import Callable, Dict
from packaging import version
from ..version import __version__
from .version_utils import (
get_neuroncc_version,
get_neuronx_distributed_version,
get_neuronxcc_version,
get_torch_version,
get_torch_xla_version,
)
def get_transformers_version() -> str:
    """Return the version string of the installed `transformers` package.

    The import happens inside the function so that merely importing this module does
    not require `transformers`; if it is not installed, calling this function raises
    `ModuleNotFoundError`.
    """
    import transformers as _transformers

    installed_version = _transformers.__version__
    return installed_version
def _get_optimum_neuron_version() -> str:
    """Return the version of `optimum-neuron` itself (read from `__version__` at call time)."""
    return __version__


# Registry mapping a package name (as accepted by `deprecate`) to a zero-argument
# callable returning that package's version string. Insertion order is kept as-is.
PACKAGE_NAME_TO_GET_VERSION_FUNCTION: Dict[str, Callable[[], str]] = {
    "transformers": get_transformers_version,
    "optimum-neuron": _get_optimum_neuron_version,
    "neuroncc": get_neuroncc_version,
    "neuronxcc": get_neuronxcc_version,
    "torch": get_torch_version,
    "torch_xla": get_torch_xla_version,
    "neuronx_distributed": get_neuronx_distributed_version,
}
def deprecate(deprecate_version: str, package_name: str = "optimum-neuron", reason: str = ""):
    """Build a decorator that marks a function as deprecated from a given package version.

    The version of `package_name` is resolved once, when `deprecate(...)` is called. If it is
    greater than or equal to `deprecate_version`, every call to the decorated function emits a
    `DeprecationWarning` (for generator functions: when the returned generator is first advanced).
    Otherwise the decorated function behaves exactly like the original.

    Args:
        deprecate_version (`str`):
            The version of `package_name` from which the decorated function is considered deprecated.
        package_name (`str`, defaults to `"optimum-neuron"`):
            The package whose version is compared; must be a key of `PACKAGE_NAME_TO_GET_VERSION_FUNCTION`.
        reason (`str`, defaults to `""`):
            Optional explanation appended to the warning message.

    Returns:
        A decorator wrapping a function (or generator function) so that it warns when deprecated.

    Raises:
        ValueError: If there is no known way to retrieve the version of `package_name`.
    """
    if package_name not in PACKAGE_NAME_TO_GET_VERSION_FUNCTION:
        raise ValueError(f"Do not know how to retrieve the version of the package called {package_name}.")
    parsed_deprecate_version = version.parse(deprecate_version)
    try:
        package_version = PACKAGE_NAME_TO_GET_VERSION_FUNCTION[package_name]()
    except ModuleNotFoundError:
        # We do not want to fail if the package is not available, otherwise it would make developing locally
        # impossible. "0.0.0" is a placeholder lower than any realistic `deprecate_version`.
        package_version = "0.0.0"
    # Computed once here: the installed package version does not change between calls.
    is_deprecated = version.parse(package_version) >= parsed_deprecate_version

    def deprecator(func):
        def warn_if_deprecated():
            if not is_deprecated:
                return
            msg = [f"{func.__name__} is deprecated."]
            if reason:
                msg.append(f"Reason: {reason}")
            # stacklevel=3 skips this helper and the wrapper so the warning is attributed to the caller,
            # which is also what lets Python's default DeprecationWarning filters show it.
            warnings.warn(" ".join(msg), category=DeprecationWarning, stacklevel=3)

        if inspect.isgeneratorfunction(func):
            # A function whose body contains `yield` is always compiled as a generator, so generator and
            # regular functions need separate wrappers. A single wrapper using `yield from` would turn
            # every decorated function into a generator and discard its return value.
            @functools.wraps(func)
            def generator_wrapper(*args, **kwargs):
                warn_if_deprecated()
                # Propagate the wrapped generator's return value (StopIteration.value) as well.
                return (yield from func(*args, **kwargs))

            return generator_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warn_if_deprecated()
            return func(*args, **kwargs)

        return wrapper

    return deprecator