# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training utilities"""

import inspect
from typing import TYPE_CHECKING, Type, Union

import torch
import transformers
from accelerate import skip_first_batches as accelerate_skip_first_batches
from transformers import GenerationMixin
from transformers.utils.logging import set_verbosity as set_verbosity_transformers

from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
from . import is_neuronx_distributed_available
from .patching import replace_class_in_inheritance_hierarchy
from .require_utils import requires_neuronx_distributed, requires_torch_xla


if is_neuronx_distributed_available():
    from neuronx_distributed.pipeline import NxDPPModel


if TYPE_CHECKING:
    from transformers import PreTrainedModel


@requires_torch_xla
def is_topology_supported() -> bool:
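    """
    Returns whether the number of Neuron devices visible to the XLA runtime corresponds to a supported topology:
    1, 2, 8, or any multiple of 32 devices.
    """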
    import torch_xla.runtime as xr

    num_devices = xr.world_size()
    allowed_number_of_devices = [1, 2, 8]
    return num_devices in allowed_number_of_devices or num_devices % 32 == 0


def patch_generation_mixin_to_neuron_generation_mixin(
    model: "PreTrainedModel", neuron_generation_mixin_cls: Type = NeuronGenerationMixin
):
    """
    Changes the vanilla `GenerationMixin` class from Transformers to `neuron_generation_mixin_cls` in the model's
    inheritance hierarchy. This makes the model Neuron-compatible for generation without much hassle.
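
    Example (illustrative sketch; assumes a causal LM checkpoint such as `gpt2` is available):

        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained("gpt2")
        patch_generation_mixin_to_neuron_generation_mixin(model)
        # Generation utilities are now resolved through `NeuronGenerationMixin`.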
    """
    return replace_class_in_inheritance_hierarchy(model, GenerationMixin, neuron_generation_mixin_cls)


def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrainedModel"):
    """
    Changes the vanilla `GenerationMixin` class from Transformers to `GeneralNeuronGenerationMixin` in the model's
    inheritance hierarchy. This makes the model Neuron-compatible for generation without much hassle.
    """
    return patch_generation_mixin_to_neuron_generation_mixin(
        model, neuron_generation_mixin_cls=GeneralNeuronGenerationMixin
    )


def set_verbosity(verbosity: int):
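    """Sets the verbosity level of both the Transformers and Optimum loggers."""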
    set_verbosity_transformers(verbosity)
    set_verbosity_optimum(verbosity)


def patch_transformers_for_neuron_sdk():
    """
    Patches the Transformers library where needed to make it work with the AWS Neuron SDK.
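
    Example (illustrative sketch):

        import logging

        import transformers

        patch_transformers_for_neuron_sdk()
        # Setting the Transformers verbosity now also updates the Optimum logger.
        transformers.utils.logging.set_verbosity(logging.INFO)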
    """
    transformers.utils.logging.set_verbosity = set_verbosity


@requires_torch_xla
def skip_first_batches(dataloader, num_batches=0):
    """
    Wrapper around `accelerate.data_loader.skip_first_batches` that also handles the loaders from
    `torch_xla.distributed.parallel_loader`, for instance when using XLA FSDP.
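
    Example (illustrative sketch; `train_dataloader` stands for any PyTorch `DataLoader`):

        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl

        device_loader = pl.MpDeviceLoader(train_dataloader, xm.xla_device())
        # Resume training by skipping the batches already consumed in a previous run.
        device_loader = skip_first_batches(device_loader, num_batches=100)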
    """
    import torch_xla.distributed.parallel_loader as pl

    if isinstance(dataloader, (pl.ParallelLoader, pl.PerDeviceLoader, pl.MpDeviceLoader)):
        dataloader._loader = skip_first_batches(dataloader._loader, num_batches=num_batches)
    else:
        dataloader = accelerate_skip_first_batches(dataloader, num_batches=num_batches)
    return dataloader


@requires_neuronx_distributed
def _get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"]):
    """Counts the number of parameters of the model."""
    import torch_xla.core.xla_model as xm
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_pipeline_model_parallel_group,
        get_pipeline_model_parallel_rank,
        get_pipeline_model_parallel_size,
        get_tensor_model_parallel_size,
        model_parallel_is_initialized,
    )
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.pipeline.partition import analyze_shared_weights_across_stages

    if isinstance(model, NxDPPModel):
        named_parameters = model.local_named_parameters()
        shared = analyze_shared_weights_across_stages(model.traced_model, model.partitions)
        shared_parameters_across_pipeline_stages = {
            t[0]: t[1] for shared_parameter_info in shared for t in shared_parameter_info
        }
    else:
        named_parameters = model.named_parameters()
        shared_parameters_across_pipeline_stages = {}

    # We make sure `named_parameters` is not an iterator because we are going to iterate over it twice.
    named_parameters = list(named_parameters)

    if torch.distributed.is_initialized() and model_parallel_is_initialized():
        tp_size = get_tensor_model_parallel_size()
        pp_size = get_pipeline_model_parallel_size()
        pp_rank = get_pipeline_model_parallel_rank()
    else:
        tp_size = 1
        pp_size = 1
        pp_rank = 0

    def numel(parameter_name, parameter) -> int:
        # Parameters shared across pipeline stages are only counted on one pipeline parallel rank to avoid
        # double counting.
        should_count_param = shared_parameters_across_pipeline_stages.get(parameter_name, pp_rank) == pp_rank

        num_elements = parameter.numel()
        if getattr(parameter, "tensor_model_parallel", False):
            num_elements *= tp_size

        if parameter.__class__.__name__ == "Params4bit":
            # bitsandbytes `Params4bit` packs several 4-bit values per storage element, so `numel()` undercounts.
            # Scale by the item size of the storage dtype to recover the actual number of parameters.
            if hasattr(parameter, "element_size"):
                num_bytes = parameter.element_size()
            elif hasattr(parameter, "quant_storage"):
                num_bytes = parameter.quant_storage.itemsize
            else:
                num_bytes = 1
            num_elements = num_elements * 2 * num_bytes

        return num_elements if should_count_param else 0

    def reduce_param_count_over_pp_ranks(param_count: int):
        param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device())
        param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
        xm.mark_step()
        param_count = int(param_count.detach().cpu().item())
        return param_count

    all_param_count = sum(numel(n, p) for n, p in named_parameters)
    trainable_param_count = sum(numel(n, p) for n, p in named_parameters if p.requires_grad)
    if pp_size > 1:
        all_param_count = reduce_param_count_over_pp_ranks(all_param_count)
        trainable_param_count = reduce_param_count_over_pp_ranks(trainable_param_count)

    return trainable_param_count, all_param_count


@requires_neuronx_distributed
def get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"], trainable_only: bool = False) -> int:
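    """
    Returns the number of parameters of the model, either all of them or only the trainable ones when
    `trainable_only` is `True`, taking tensor and pipeline parallelism into account.

    Example (illustrative sketch; assumes `model` is the model being trained):

        total_params = get_model_param_count(model)
        trainable_params = get_model_param_count(model, trainable_only=True)
    """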
    trainable_param_count, all_param_count = _get_model_param_count(model)
    if trainable_only:
        output = trainable_param_count
    else:
        output = all_param_count
    return output


@requires_neuronx_distributed
def is_main_worker_for_metrics() -> bool:
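    """
    Returns whether the current process is the one responsible for logging metrics: the worker with data parallel
    rank 0, tensor parallel rank 0 and the last pipeline parallel rank, which is the stage that computes the loss.
    """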
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_data_parallel_rank,
        get_pipeline_model_parallel_rank,
        get_pipeline_model_parallel_size,
        get_tensor_model_parallel_rank,
    )

    if not torch.distributed.is_initialized():
        return True

    dp_rank = get_data_parallel_rank()
    tp_rank = get_tensor_model_parallel_rank()
    pp_rank = get_pipeline_model_parallel_rank()
    pp_size = get_pipeline_model_parallel_size()

    can_log_loss = dp_rank == tp_rank == 0 and pp_rank == pp_size - 1

    return can_log_loss


def is_main_worker_for_metrics_method(self) -> bool:
    """
    Method version of `is_main_worker_for_metrics`, useful for patching methods of the `Trainer` class.
    """
    return is_main_worker_for_metrics()


def is_custom_modeling_model(model) -> bool:
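    """
    Returns whether `model`, or its base model when wrapped with PEFT, comes from the custom modeling
    implementations under `optimum.neuron.models.training`.
    """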
    from peft import PeftModel

    model_to_consider = model
    if isinstance(model, PeftModel):
        model_to_consider = model.get_base_model()
    return inspect.getmodule(model_to_consider.__class__).__name__.startswith("optimum.neuron.models.training")
