# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities."""

import importlib.util
from typing import Optional

from packaging import version


# Inclusive lower bounds used as the default `min_version` of the
# `is_accelerate_available` / `is_peft_available` checks below.
MIN_ACCELERATE_VERSION: str = "0.20.1"
MIN_PEFT_VERSION: str = "0.14.0"


def is_neuron_available() -> bool:
    """Report whether the `torch_neuron` package can be located on the import path.

    Only the module spec is looked up; the package itself is not imported.
    """
    spec = importlib.util.find_spec("torch_neuron")
    return spec is not None


def is_neuronx_available() -> bool:
    """Report whether the `torch_neuronx` package can be located on the import path.

    Only the module spec is looked up; the package itself is not imported.
    """
    spec = importlib.util.find_spec("torch_neuronx")
    return spec is not None


def is_torch_xla_available() -> bool:
    """Return whether `torch_xla` is both installed and importable.

    Locating the package is not enough: an installed but broken or incompatible
    `torch_xla` can fail at import time. This function therefore actually
    imports it.

    Returns:
        `True` if `torch_xla` is found on the import path and imports without
        raising, `False` otherwise.
    """
    if importlib.util.find_spec("torch_xla") is None:
        return False
    try:
        # The import is the availability probe itself; the bound name is unused.
        import torch_xla  # noqa: F401
    except Exception:
        # Deliberately broad: a broken install may raise arbitrary errors at
        # import time, and any of them means "not usable" here.
        return False
    return True


def is_neuronx_distributed_available() -> bool:
    """Report whether the `neuronx_distributed` package can be located on the import path.

    Only the module spec is looked up; the package itself is not imported.
    """
    spec = importlib.util.find_spec("neuronx_distributed")
    return spec is not None


def is_accelerate_available(min_version: Optional[str] = MIN_ACCELERATE_VERSION) -> bool:
    """Check that `accelerate` is installed, optionally at or above a minimum version.

    Args:
        min_version: Lowest acceptable `accelerate` version (inclusive). Pass `None`
            to only check that the package can be located, without importing it.

    Returns:
        `True` when `accelerate` is found and, if `min_version` is given, its
        `__version__` is >= `min_version`. `False` otherwise.
    """
    found = importlib.util.find_spec("accelerate") is not None
    if min_version is None or not found:
        return found

    import accelerate

    installed = version.parse(accelerate.__version__)
    return installed >= version.parse(min_version)


def is_torch_neuronx_available() -> bool:
    """Report whether the `torch_neuronx` package can be located on the import path.

    This performs the same `torch_neuronx` lookup as `is_neuronx_available`. Only
    the module spec is looked up; the package itself is not imported.
    """
    spec = importlib.util.find_spec("torch_neuronx")
    return spec is not None


def is_trl_available(required_version: Optional[str] = None) -> bool:
    """Check whether `trl` is installed, enforcing an exact version pin if one is given.

    Args:
        required_version: Exact `trl` version that is supported. When `None`,
            any installed version is accepted.

    Returns:
        `False` if `trl` cannot be located, `True` if it is installed and matches
        `required_version` (or no pin was given).

    Raises:
        RuntimeError: If `trl` is installed but its version differs from
            `required_version`.
    """
    if importlib.util.find_spec("trl") is None:
        return False

    import trl

    installed_version = trl.__version__
    # Without an explicit pin, compare the installed version against itself (always accepted).
    expected_version = installed_version if required_version is None else required_version
    if version.parse(installed_version) != version.parse(expected_version):
        raise RuntimeError(f"Only `trl=={expected_version}` is supported, but {installed_version} is installed.")
    return True


def is_peft_available(min_version: Optional[str] = MIN_PEFT_VERSION) -> bool:
    """Check that `peft` is installed, optionally at or above a minimum version.

    Args:
        min_version: Lowest acceptable `peft` version (inclusive). Pass `None` to
            only check that the package can be located, without importing it.

    Returns:
        `True` when `peft` is found and, if `min_version` is given, its
        `__version__` is >= `min_version`. `False` otherwise.
    """
    peft_found = importlib.util.find_spec("peft") is not None
    if not peft_found:
        return False
    if min_version is None:
        return True

    import peft

    return version.parse(peft.__version__) >= version.parse(min_version)
