optimum/neuron/utils/require_utils.py (34 lines of code) (raw):
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to handle the several optional requirement packages."""
import functools
from typing import Any, Callable, Dict
from transformers.utils import is_safetensors_available
from .import_utils import (
is_neuronx_distributed_available,
is_peft_available,
is_torch_neuronx_available,
is_torch_xla_available,
)
_AVAILABILITIES: Dict[str, Callable] = {
"safetensors": is_safetensors_available,
"torch_xla": is_torch_xla_available,
"neuronx_distributed": is_neuronx_distributed_available,
"torch_neuronx": is_torch_neuronx_available,
"peft": is_peft_available,
}
def _create_requires_function(package_name: str) -> Callable[..., Any]:
availability_function = _AVAILABILITIES[package_name]
def require_func(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not availability_function():
raise ModuleNotFoundError(
f"{func.__name__} requires the `{package_name}` package. You can install it by running: pip "
f"install {package_name}"
)
return func(*args, **kwargs)
return wrapper
return require_func
requires_safetensors = _create_requires_function("safetensors")
requires_torch_xla = _create_requires_function("torch_xla")
requires_neuronx_distributed = _create_requires_function("neuronx_distributed")
requires_torch_neuronx = _create_requires_function("torch_neuronx")
requires_peft = _create_requires_function("peft")