# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch
import torchao
from packaging.version import parse
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

from ..integrations import CausalLMExportableModule
from ..task_registry import register_task


# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
@register_task("text-generation")
def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
    """
    Loads a causal language model with Hugging Face's `AutoModelForCausalLM` and wraps it in a
    `CausalLMExportableModule` for ExecuTorch export. Registered under the 'text-generation' task.

    Args:
        model_name_or_path (str):
            Model ID on huggingface.co or path on disk to the model repository to export. For example:
            `model_name_or_path="meta-llama/Llama-3.2-1B"` or `model_name_or_path="/path/to/model_folder"`.
        **kwargs:
            Additional configuration options for the model:
                - dtype (str, optional):
                    Data type for model weights (default: "float32").
                    Options include "float16" and "bfloat16".
                - attn_implementation (str, optional):
                    Attention mechanism implementation (default: "custom_sdpa" when
                    `use_custom_sdpa=True`, otherwise "sdpa").
                - cache_implementation (str, optional):
                    Cache management strategy (default: "static").
                - max_length (int, optional):
                    Maximum sequence length for generation (default: 2048).
                - use_custom_sdpa (bool, optional):
                    Whether to use the custom SDPA implementation; implied when
                    `attn_implementation="custom_sdpa"` (default: False).
                - use_custom_kv_cache (bool, optional):
                    Whether to use the custom KV cache (default: False).
                - config (PretrainedConfig, optional):
                    Model configuration; loaded from `model_name_or_path` when not provided.
                - qlinear (optional):
                    When set, quantizes linear layers with 8-bit dynamic activations and
                    4-bit grouped weights (requires torchao >= 0.11.0).
                - qembedding (optional):
                    When set, quantizes embedding layers with int8 per-axis weight-only
                    quantization (requires torchao >= 0.11.0).

    Returns:
        CausalLMExportableModule:
            An instance of `CausalLMExportableModule` for exporting and lowering to ExecuTorch.
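
    Example:
        A minimal, illustrative call; the model ID and argument values are placeholders drawn
        from the descriptions above, not recommendations:

        >>> module = load_causal_lm_model(
        ...     "meta-llama/Llama-3.2-1B",
        ...     dtype="bfloat16",
        ...     max_length=1024,
        ... )
        >>> # `module` can then be exported and lowered to ExecuTorch.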
    """
    device = "cpu"
    batch_size = 1
    dtype = kwargs.get("dtype", "float32")
    use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
    use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
    cache_implementation = kwargs.get("cache_implementation", "static")
    use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
    max_length = kwargs.get("max_length", 2048)
    config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)

    if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
        # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite
        # that function to avoid the data-dependent control flow.
        config.rope_scaling["type"] = "default"

    eager_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device,
        torch_dtype=dtype,
        config=config,
        attn_implementation=attn_implementation,
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation=cache_implementation,
            max_length=max_length,
            cache_config={
                "batch_size": batch_size,
                "max_cache_len": max_length,
            },
        ),
    )

    for param in eager_model.parameters():
        # Gradients must be disabled for quantized checkpoints
        if isinstance(param, torchao.utils.TorchAOBaseTensor):
            param.requires_grad = False

    # TODO: Move quantization recipe out for better composability.
    # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
    qlinear_config = kwargs.get("qlinear", None)
    qembedding_config = kwargs.get("qembedding", None)
    if qlinear_config or qembedding_config:
        # TODO: Require the stable torchao 0.11.0 release once it is available.
        if parse(torchao.__version__) < parse("0.11.0.dev0"):
            raise RuntimeError("Quantization requires torchao >= 0.11.0. Please upgrade torchao.")

        from torchao.quantization.granularity import PerAxis, PerGroup
        from torchao.quantization.quant_api import (
            Int8DynamicActivationIntxWeightConfig,
            IntxWeightOnlyConfig,
            quantize_,
        )
        from torchao.utils import unwrap_tensor_subclass

        if qembedding_config:
            logging.info("Quantizing embedding layers.")
            # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
            embedding_config = IntxWeightOnlyConfig(
                weight_dtype=torch.int8,
                granularity=PerAxis(0),
            )
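            # The filter function below restricts this config to nn.Embedding modules,
            # leaving linear layers untouched.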
            quantize_(
                eager_model,
                embedding_config,
                lambda m, fqn: isinstance(m, torch.nn.Embedding),
            )

        if qlinear_config:
            logging.info("Quantizing linear layers.")
            linear_config = Int8DynamicActivationIntxWeightConfig(
                weight_dtype=torch.int4,
                weight_granularity=PerGroup(32),
            )
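            # With no filter function passed, quantize_ applies this config to linear layers by default.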
            quantize_(
                eager_model,
                linear_config,
            )

        # Unwrap tensor subclasses into plain tensors (with parametrizations) so the
        # quantized model remains compatible with torch.export.
        unwrap_tensor_subclass(eager_model)

    return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
