# optimum/neuron/utils/import_utils.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities."""
import importlib.util
from typing import Optional
from packaging import version
# Default minimum `accelerate` version required by `is_accelerate_available`.
MIN_ACCELERATE_VERSION = "0.20.1"
# Default minimum `peft` version required by `is_peft_available`.
MIN_PEFT_VERSION = "0.14.0"
def is_neuron_available() -> bool:
    """Report whether the ``torch_neuron`` package can be located.

    Only the import system's spec lookup is consulted; the package itself is
    not imported.

    Returns:
        ``True`` when ``importlib.util.find_spec("torch_neuron")`` yields a
        spec, ``False`` otherwise.
    """
    spec = importlib.util.find_spec("torch_neuron")
    return spec is not None
def is_neuronx_available() -> bool:
    """Report whether the ``torch_neuronx`` package can be located.

    The check relies solely on ``importlib.util.find_spec``; nothing is
    imported.

    Returns:
        ``True`` if a spec for ``torch_neuronx`` is found, else ``False``.
    """
    if importlib.util.find_spec("torch_neuronx") is None:
        return False
    return True
def is_torch_xla_available() -> bool:
    """Check whether ``torch_xla`` is installed *and* can be imported.

    Finding the package spec is not sufficient: an installation can be
    discoverable by the import system yet fail while being imported. For that
    reason the import is actually attempted once the spec has been found.

    Returns:
        ``True`` if ``torch_xla`` is found and importing it succeeds,
        ``False`` otherwise.
    """
    if importlib.util.find_spec("torch_xla") is None:
        return False
    try:
        import torch_xla  # noqa: F401
    except Exception:
        # Deliberately broad: any exception raised while importing (not only
        # ImportError) means torch_xla is not usable, so report it as missing.
        return False
    return True
def is_neuronx_distributed_available() -> bool:
    """Report whether the ``neuronx_distributed`` package can be located.

    Returns:
        ``True`` when the import system finds a spec for
        ``neuronx_distributed``, ``False`` otherwise. The package is not
        imported.
    """
    package_spec = importlib.util.find_spec("neuronx_distributed")
    found = package_spec is not None
    return found
def is_accelerate_available(min_version: Optional[str] = MIN_ACCELERATE_VERSION) -> bool:
    """Check whether ``accelerate`` is installed, optionally at a minimum version.

    Args:
        min_version: Lowest acceptable ``accelerate`` version. When ``None``,
            only the presence of the package is checked and it is not imported.

    Returns:
        ``True`` if ``accelerate`` can be located and, when ``min_version`` is
        given, ``accelerate.__version__`` is greater than or equal to it;
        ``False`` otherwise.
    """
    if importlib.util.find_spec("accelerate") is None:
        return False
    if min_version is None:
        return True
    import accelerate

    installed = version.parse(accelerate.__version__)
    return installed >= version.parse(min_version)
def is_torch_neuronx_available() -> bool:
    """Report whether the ``torch_neuronx`` package can be located.

    This performs the same spec lookup (``"torch_neuronx"``) as
    ``is_neuronx_available``; the package is not imported.

    Returns:
        ``True`` if a spec for ``torch_neuronx`` is found, else ``False``.
    """
    found_spec = importlib.util.find_spec("torch_neuronx")
    return found_spec is not None
def is_trl_available(required_version: Optional[str] = None) -> bool:
    """Check whether ``trl`` is installed, optionally pinned to an exact version.

    Args:
        required_version: Exact ``trl`` version that must be installed. When
            ``None``, the installed version is accepted as-is.

    Returns:
        ``False`` if ``trl`` cannot be located, ``True`` if it is installed and
        its version matches ``required_version`` (or no version was required).

    Raises:
        RuntimeError: If ``trl`` is installed but its version differs from
            ``required_version``.
    """
    if importlib.util.find_spec("trl") is None:
        return False
    import trl

    installed = trl.__version__
    # With no explicit requirement, the installed version trivially matches.
    expected = installed if required_version is None else required_version
    if version.parse(installed) != version.parse(expected):
        raise RuntimeError(f"Only `trl=={expected}` is supported, but {installed} is installed.")
    return True
def is_peft_available(min_version: Optional[str] = MIN_PEFT_VERSION) -> bool:
    """Check whether ``peft`` is installed, optionally at a minimum version.

    Args:
        min_version: Lowest acceptable ``peft`` version. When ``None``, only
            the presence of the package is checked and it is not imported.

    Returns:
        ``True`` if ``peft`` can be located and, when ``min_version`` is given,
        ``peft.__version__`` is greater than or equal to it; ``False``
        otherwise.
    """
    if importlib.util.find_spec("peft") is None:
        return False
    if min_version is None:
        return True
    import peft

    installed = version.parse(peft.__version__)
    return installed >= version.parse(min_version)