# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os

import torch
from accelerate.utils import str_to_bool
from transformers.utils import is_peft_available


if is_peft_available():
    from peft.tuners import lora


# Module-level availability flag; `import_te()` rewrites it after each import attempt.
has_transformer_engine = False


def import_te():
    """
    Attempt to import Habana's transformer engine and bind it to the module-level global `te`.

    Side effects:
        - On success, `te` is bound and `has_transformer_engine` becomes `True`.
        - On `ImportError`, `has_transformer_engine` becomes `False` and `te` is left as it was
          (i.e. undefined if no earlier import succeeded).
    """
    global te, has_transformer_engine
    try:
        import habana_frameworks.torch.hpex.experimental.transformer_engine as te
    except ImportError:
        has_transformer_engine = False
    else:
        has_transformer_engine = True


def is_fp8_available():
    """
    Return whether Habana's transformer engine is usable.

    Once a previous import has succeeded the cached flag is returned directly; otherwise the import is
    (re)attempted through `import_te()` on every call.
    """
    if has_transformer_engine:
        return True
    import_te()
    return has_transformer_engine


def _make_te_linear(linear, minimize_memory):
    """
    Build a `te.Linear` that shares (does not copy) the weight and bias of `linear`.

    Args:
        linear: linear-like module exposing `in_features`, `out_features`, `weight` and `bias`
            (a `torch.nn.Linear`, or the `base_layer` of a PEFT LoRA linear).
        minimize_memory (`bool`): forwarded to `te.Linear(minimize_memory=...)`.

    Returns:
        The new `te.Linear` module.
    """
    has_bias = linear.bias is not None
    # Create the TE linear without its own weight/bias (`skip_weight_param_allocation=True`) and then
    # attach the original parameters, so both modules refer to the same tensors.
    te_module = te.Linear(
        linear.in_features,
        linear.out_features,
        bias=has_bias,
        params_dtype=linear.weight.dtype,
        skip_weight_param_allocation=True,
        minimize_memory=minimize_memory,
    )
    te_module.weight = linear.weight
    if has_bias:
        te_module.bias = linear.bias
    return te_module


def _make_torch_linear(te_linear):
    """
    Build a `torch.nn.Linear` on the same device/dtype holding a copy of the weight and bias of `te_linear`.
    """
    has_bias = te_linear.bias is not None
    new_module = torch.nn.Linear(
        te_linear.in_features,
        te_linear.out_features,
        bias=has_bias,
        dtype=te_linear.weight.dtype,
        device=te_linear.weight.device,
    )
    # In-place `copy_` into a leaf parameter that requires grad is rejected by autograd unless grad
    # recording is disabled, so do not rely on the caller having entered `torch.no_grad()`.
    with torch.no_grad():
        new_module.weight.copy_(te_linear.weight)
        if has_bias:
            new_module.bias.copy_(te_linear.bias)
    return new_module


def _make_te_fused_sdpa(fsdpa_module):
    """
    Build a module that runs Habana TE's `FusedAttention`, configured with the `scale`, `attention_dropout`
    and `enable_recompute` attributes of `fsdpa_module` (a `ModuleFusedSDPA`).
    """
    from habana_frameworks.torch.hpex.experimental.transformer_engine import (
        FusedAttention as TE_FusedAttention,
    )

    class TE_ModuleFusedSDPA(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._hpu_kernel_fsdpa = TE_FusedAttention(
                scale=fsdpa_module.scale,
                attention_dropout=fsdpa_module.attention_dropout,
                enable_recompute=fsdpa_module.enable_recompute,
            )

        def forward(
            self,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            softmax_mode,
            recompute_mode,
            valid_sequence_lengths,
            padding_side="left",
        ):
            # NOTE(review): signature presumably mirrors `ModuleFusedSDPA.forward` so this module can be
            # swapped in place — confirm. Only query/key/value/attn_mask/is_causal/softmax_mode are forwarded;
            # dropout_p, scale, recompute_mode, valid_sequence_lengths and padding_side are ignored here.
            return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode)

    return TE_ModuleFusedSDPA()


def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
    """
    Recursively convert, in place, the linear layers of a model to/from their `transformers_engine` counterpart.

    Args:
        model (`torch.nn.Module`): module whose children are replaced in place.
        to_transformer_engine (`bool`, defaults to `True`): conversion direction.
            - `True`: `torch.nn.Linear` modules and the `base_layer` of PEFT LoRA linears become `te.Linear`
              (sharing the original parameters), and `ModuleFusedSDPA` modules with `flash_attention_fp8`
              set are replaced by a TE `FusedAttention` wrapper.
            - `False`: `te.Linear` modules are replaced by `torch.nn.Linear` holding a copy of their parameters.
        _convert_linear (`bool`, defaults to `True`): when `False`, linear layers are left untouched.

    Any child that is not converted is recursed into (LoRA linears are not recursed into).

    Raises:
        ImportError: if the Habana transformer engine cannot be imported.
    """
    from ...transformers.models.llama.modeling_llama import ModuleFusedSDPA

    if not is_fp8_available():
        raise ImportError("Using `convert_model` requires transformer_engine to be installed.")

    minimize_memory = str_to_bool(os.getenv("PT_HPU_FP8_MINIMIZE_MEMORY", "false"))

    for name, module in model.named_children():
        if is_peft_available() and isinstance(module, lora.Linear) and to_transformer_engine and _convert_linear:
            # For lora linear module, convert only base linear layer to fp8 and skip lora-a,
            # lora-b linear layers. Since lora-a, lora-b are small in size, there is not much
            # device performance gain by pushing these in fp8. This way we avoid host overhead
            # associated with using TE for these layers.
            for child_name, child in module.named_children():
                if child_name == "base_layer":
                    setattr(module, child_name, _make_te_linear(child, minimize_memory))
        elif isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
            setattr(model, name, _make_te_linear(module, minimize_memory))
        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
            setattr(model, name, _make_torch_linear(module))
        elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine:
            setattr(model, name, _make_te_fused_sdpa(module))
        else:
            _convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)


def has_transformer_engine_layers(model):
    """
    Check whether `model` (or any of its submodules, recursively) is a `te.Linear`.

    Args:
        model (`torch.nn.Module`): the model to inspect.

    Returns:
        `bool`: `True` as soon as one `te.Linear` is found, `False` otherwise.

    Raises:
        ImportError: if the Habana transformer engine cannot be imported.
    """
    if not is_fp8_available():
        raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
    return any(isinstance(submodule, te.Linear) for submodule in model.modules())


def convert_model(model):
    """
    Replace, in place, the `torch.nn.Linear` modules of `model` with `transformers_engine` Linear modules.

    Conversion is skipped entirely when the model already contains a `te.Linear`; otherwise it runs under
    `torch.no_grad()` and the model is tagged with `_converted_to_transformer_engine = True`.

    Returns:
        The same `model` object.

    Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1303
    """
    if has_transformer_engine_layers(model):
        # The model already holds `te.Linear` layers: leave it untouched.
        return model

    with torch.no_grad():
        _convert_model(model)
    model._converted_to_transformer_engine = True
    return model


def get_fp8_recipe(fp8_recipe_handler):
    """
    Build a transformer engine `DelayedScaling` FP8 recipe from a recipe handler.

    Args:
        fp8_recipe_handler: object with a `to_dict()` method whose entries are passed as keyword arguments
            to `te.recipe.DelayedScaling`, or `None` to use the recipe defaults.

    Returns:
        The `te.recipe.DelayedScaling` instance, with its `backend` attribute set to `"TE"`.

    Raises:
        ImportError: if the Habana transformer engine cannot be imported.

    Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L1309
    """
    if not is_fp8_available():
        raise ImportError("Using `get_fp8_recipe` requires transformer_engine to be installed.")

    recipe_kwargs = {} if fp8_recipe_handler is None else fp8_recipe_handler.to_dict()
    if "fp8_format" in recipe_kwargs:
        # The handler stores the format as an attribute name of `te.recipe.Format`; resolve it to that member.
        format_name = recipe_kwargs["fp8_format"]
        recipe_kwargs["fp8_format"] = getattr(te.recipe.Format, format_name)

    recipe = te.recipe.DelayedScaling(**recipe_kwargs)
    recipe.backend = "TE"
    return recipe


class FP8ContextWrapper:
    """
    Helper class for FP8 context related operations.

    Used as a context manager, it enters a caller-supplied context `ctx` and then a transformer engine
    `fp8_autocast` context, and exits them in reverse order. Either context is always unwound, even when
    entering or exiting the other one fails. It also provides static helpers that wrap a model's gradient
    checkpointing function with TE's `activation_checkpointing` context.

    Args:
        ctx: any context manager; entered first and exited last.
        fp8_recipe: recipe passed to `te.fp8_autocast(fp8_recipe=...)` (presumably the object returned by
            `get_fp8_recipe` — confirm at call sites).
    """

    def __init__(self, ctx, fp8_recipe):
        self.ctx = ctx
        self.fp8_ctx = self.create_fp8_context(fp8_recipe)

    def __enter__(self):
        self.ctx.__enter__()
        try:
            self.fp8_ctx.__enter__()
        except BaseException as err:
            # Python never calls `__exit__` when `__enter__` raises, so the already-active outer context
            # must be released here or it would leak.
            self.ctx.__exit__(type(err), err, err.__traceback__)
            raise

    def __exit__(self, exc_type, exc_value, exc_traceback):
        try:
            self.fp8_ctx.__exit__(exc_type, exc_value, exc_traceback)
        except BaseException as err:
            # Still unwind the outer context, reporting the exception raised while leaving the FP8 one.
            self.ctx.__exit__(type(err), err, err.__traceback__)
            raise
        self.ctx.__exit__(exc_type, exc_value, exc_traceback)

    @staticmethod
    def create_fp8_context(fp8_recipe):
        """Return an enabled `te.fp8_autocast` context configured with `fp8_recipe`."""
        return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)

    @staticmethod
    def _gradient_checkpointing_wrap(func, *args, **kwargs):
        """
        `_gradient_checkpointing_func` always takes the function to be recomputed as the first argument. The function
        below wraps this first argument with `transformer_engine`'s `activation_checkpointing` context.
        """
        checkpointed_fn = te.distributed.activation_checkpointing()(args[0])
        return func(checkpointed_fn, *args[1:], **kwargs)

    @staticmethod
    def gradient_checkpointing_wrap(model):
        """
        Wrap `_gradient_checkpointing_func` in the model with `transformer_engine`'s `activation_checkpointing` context.
        This context is used to signal the `transformer_engine` modules whether they have been called with activation checkpointing enabled or not.

        If `model` itself has gradient checkpointing enabled, only its own function is wrapped; otherwise every
        submodule with gradient checkpointing enabled is wrapped. Nothing guards against wrapping twice, so
        call this once per model.
        """
        if getattr(model, "gradient_checkpointing", False):
            model._gradient_checkpointing_func = functools.partial(
                FP8ContextWrapper._gradient_checkpointing_wrap, model._gradient_checkpointing_func
            )
            return

        for module in model.modules():
            if getattr(module, "gradient_checkpointing", False):
                module._gradient_checkpointing_func = functools.partial(
                    FP8ContextWrapper._gradient_checkpointing_wrap, module._gradient_checkpointing_func
                )
