# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NeuroModelForXXX classes for seq2seq models' inference on Neuron devices."""

import copy
import logging
import os
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.utils import ModelOutput

from ..exporters.neuron import (
    NeuronDefaultConfig,
    main_export,
)
from ..exporters.tasks import TasksManager
from ..utils.save_utils import maybe_load_preprocessors
from .generation import NeuronGenerationMixin
from .modeling_traced import NeuronTracedModel
from .utils import (
    DECODER_NAME,
    ENCODER_NAME,
    NEURON_FILE_NAME,
    is_neuronx_available,
    is_neuronx_distributed_available,
)
from .utils.doc import (
    _TOKENIZER_FOR_DOC,
    NEURON_SEQ2SEQ_INPUTS_DOCSTRING,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
    NEURON_TRANSLATION_EXAMPLE,
    NEURON_TRANSLATION_TP_EXAMPLE,
)


if TYPE_CHECKING:
    # Names needed only by the quoted type annotations in this module.
    from transformers import PretrainedConfig, PreTrainedModel

# Optional Neuron packages: each one is imported only when its availability helper returns a truthy value.
# `torch_neuronx` is used by `load_model` on the single-core path.
if is_neuronx_available():
    import torch_neuronx

# `neuronx_distributed` is used by `load_model` on the tensor-parallel path.
if is_neuronx_distributed_available():
    import neuronx_distributed

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class _NeuronSeq2SeqModelPart(ABC):
    """
    For Seq2Seq architecture, we usually compile it to multiple neuron models. Each represents a part of the model.

    This is the abstract base of those parts: it stores the traced module together with its metadata and makes
    instances callable by delegating `__call__` to `forward`. It derives from `ABC` so that `@abstractmethod` is
    actually enforced: a subclass that does not implement `forward` cannot be instantiated.

    Args:
        model (`torch.jit._script.ScriptModule`):
            The traced module wrapped by this part.
        parent_model (`NeuronTracedModel`):
            The model object that owns this part.
        config (`Optional[PretrainedConfig]`, defaults to `None`):
            Configuration associated with this part.
        neuron_config (`Optional[NeuronDefaultConfig]`, defaults to `None`):
            Neuron configuration associated with this part.
        model_type (`str`, defaults to `"encoder"`):
            Label identifying which part of the model this is.
        device (`Optional[str]`, defaults to `None`):
            Stored as-is on the instance.
    """

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: "NeuronTracedModel",
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional["NeuronDefaultConfig"] = None,
        model_type: str = "encoder",
        device: Optional[str] = None,
    ):
        self.model = model
        self.parent_model = parent_model
        self.config = config
        self.neuron_config = neuron_config
        self.model_type = model_type
        self.device = device

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Runs the wrapped module; must be implemented by every concrete part."""
        pass

    def __call__(self, *args, **kwargs):
        """Forwards all arguments to `forward` and returns its result."""
        return self.forward(*args, **kwargs)


class NeuronEncoder(_NeuronSeq2SeqModelPart):
    """
    Wrapper around the traced "encoder" artifact of an encoder-decoder model for Neuron inference.

    Despite the name, the traced module is a monolith of encoder + decoder without past_key_values, which works
    around the control flow in the decoder.
    """

    main_input_name = "input_ids"

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronTracedModel,
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        """
        Registers the traced module as the `"encoder"` part of `parent_model`.

        NOTE(review): `neuron_config` is annotated as a dict here while the base class annotates it as
        `NeuronDefaultConfig`; confirm which type callers actually pass.
        """
        super().__init__(
            model=model,
            parent_model=parent_model,
            config=config,
            neuron_config=neuron_config,
            model_type="encoder",
        )

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor):
        """Calls the traced module with `(input_ids, attention_mask)` and returns its raw outputs."""
        return self.model(input_ids, attention_mask)


class NeuronDecoder(_NeuronSeq2SeqModelPart):
    """
    Wrapper around the traced decoder artifact of an encoder-decoder model for Neuron inference.

    The traced module is the decoder with past_key_values.
    """

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronTracedModel,
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        """
        Registers the traced module as the `"decoder"` part of `parent_model`.

        NOTE(review): `neuron_config` is annotated as a dict here while the base class annotates it as
        `NeuronDefaultConfig`; confirm which type callers actually pass.
        """
        super().__init__(
            model=model,
            parent_model=parent_model,
            config=config,
            neuron_config=neuron_config,
            model_type="decoder",
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        decoder_attention_mask: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask: torch.FloatTensor,
        beam_idx: torch.LongTensor,
        beam_scores: torch.FloatTensor,
    ):
        """
        Calls the traced module with the six inputs, in the order of this signature, and returns its raw outputs.
        """
        return self.model(
            input_ids,
            decoder_attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            beam_idx,
            beam_scores,
        )


class NeuronModelForConditionalGeneration(NeuronTracedModel, ABC):
    """
    Base class for encoder-decoder models that are traced into two separate Neuron artifacts.

    It loads or exports the encoder and decoder artifacts together with their configurations, and wraps the traced
    modules with `encoder_class` and `decoder_class`.
    """

    base_model_prefix = "neuron_model"
    # File name of the configuration that `_from_pretrained` looks up in each sub-model folder.
    config_name = "config.json"
    # Wrapper classes instantiated in `__init__` around the traced encoder and decoder modules.
    encoder_class = NeuronEncoder
    decoder_class = NeuronDecoder

    def __init__(
        self,
        encoder: torch.jit._script.ScriptModule,
        decoder: torch.jit._script.ScriptModule,
        config: "PretrainedConfig",
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        encoder_file_name: Optional[str] = NEURON_FILE_NAME,
        decoder_file_name: Optional[str] = NEURON_FILE_NAME,
        preprocessors: Optional[List] = None,
        neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None,
        configs: Optional[Dict[str, "PretrainedConfig"]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """
        Wraps the traced encoder and decoder and records their configurations.

        Args:
            encoder (`torch.jit._script.ScriptModule`):
                Traced module wrapped with `encoder_class`.
            decoder (`torch.jit._script.ScriptModule`):
                Traced module wrapped with `decoder_class`.
            config (`PretrainedConfig`):
                Configuration stored as `self.config`.
            model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`):
                Forwarded to `_attributes_init`.
            encoder_file_name (`Optional[str]`, defaults to `NEURON_FILE_NAME`):
                Stored as `self.encoder_file_name`.
            decoder_file_name (`Optional[str]`, defaults to `NEURON_FILE_NAME`):
                Stored as `self.decoder_file_name`.
            preprocessors (`Optional[List]`, defaults to `None`):
                Forwarded to `_attributes_init`.
            neuron_configs (`Optional[Dict[str, NeuronDefaultConfig]]`, defaults to `None`):
                Neuron configurations indexed with `ENCODER_NAME` and `DECODER_NAME`. Although the default is
                `None`, both entries are read unconditionally.
            configs (`Optional[Dict[str, PretrainedConfig]]`, defaults to `None`):
                Sub-model configurations indexed with `ENCODER_NAME` and `DECODER_NAME`. Although the default is
                `None`, both entries are read unconditionally.
            generation_config (`Optional[GenerationConfig]`, defaults to `None`):
                When omitted, one is derived from the decoder configuration.
            kwargs:
                Forwarded to `_attributes_init`.
        """
        self.config = config
        self.configs = configs
        self.neuron_configs = neuron_configs

        # Static input shapes are tracked only for the encoder.
        encoder_neuron_config = self.neuron_configs[ENCODER_NAME]
        self.input_static_shapes = NeuronModelForConditionalGeneration.get_input_static_shapes(
            encoder_neuron_config
        )
        self._attributes_init(model_save_dir, preprocessors, **kwargs)

        self.encoder = self.encoder_class(
            encoder, self, self.configs[ENCODER_NAME], self.neuron_configs[ENCODER_NAME]
        )
        self.decoder = self.decoder_class(
            decoder, self, self.configs[DECODER_NAME], self.neuron_configs[DECODER_NAME]
        )

        # True only when every sub-model reports `dynamic_batch_size` in its neuron settings.
        self.dynamic_batch_size = all(
            part_neuron_config._config.neuron["dynamic_batch_size"]
            for part_neuron_config in self.neuron_configs.values()
        )
        self.encoder_file_name = encoder_file_name
        self.decoder_file_name = decoder_file_name

        self.generation_config = (
            GenerationConfig.from_model_config(self.configs[DECODER_NAME])
            if generation_config is None
            else generation_config
        )
        self.tensor_parallel_size = self.neuron_configs[DECODER_NAME].tensor_parallel_size

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """
        Writes the model files and the generation config to `save_directory`, so that the model can be re-loaded
        using the [`~optimum.neuron.modeling_traced.NeuronTracedModel.from_pretrained`] class method.

        Args:
            save_directory (`Union[str, Path]`):
                Destination directory; files already present there may be overwritten.
        """
        # The content of the current save directory is copied over as-is.
        source_dir, target_dir = self.model_save_dir, save_directory
        shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)
        self.generation_config.save_pretrained(target_dir)

    @staticmethod
    def load_model(
        encoder_path: Union[str, Path],
        decoder_path: Union[str, Path],
        tensor_parallel_size: int,
    ):
        """
        Loads the traced encoder and decoder from disk.

        Args:
            encoder_path (`Union[str, Path]`):
                Path of the encoder artifact.
            decoder_path (`Union[str, Path]`):
                Path of the decoder artifact.
            tensor_parallel_size (`int`):
                `1` selects the single-core loading path; any other value loads both artifacts with
                `neuronx_distributed.trace.parallel_model_load`.

        Returns:
            `tuple`: `(encoder, decoder)`.
        """
        if tensor_parallel_size != 1:
            load_sharded = neuronx_distributed.trace.parallel_model_load
            return load_sharded(encoder_path), load_sharded(decoder_path)

        # Initialize Neuron Runtime before loading models
        neuron_runtime = torch.classes.neuron.Runtime()
        neuron_runtime.initialize()
        neuron_runtime.set_default_neuron_cores(0, 1)

        traced_encoder = NeuronTracedModel.load_model(encoder_path)
        traced_decoder = NeuronTracedModel.load_model(decoder_path)
        # Only the decoder trace is explicitly moved (to device index 0).
        torch_neuronx.move_trace_to_device(traced_decoder, 0)
        return traced_encoder, traced_decoder

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        encoder_file_name: Optional[str] = NEURON_FILE_NAME,
        decoder_file_name: Optional[str] = NEURON_FILE_NAME,
        subfolder: str = "",
        local_files_only: bool = False,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        **kwargs,
    ):
        """
        Loads the compiled encoder and decoder, their configurations and the generation config.

        Args:
            model_id (`Union[str, Path]`):
                A local directory, or a repository id that is fetched with `snapshot_download` when it is not an
                existing directory.
            config (`PretrainedConfig`):
                Configuration passed through to the constructor.
            token (`Optional[Union[bool, str]]`, defaults to `None`):
                Forwarded to `snapshot_download` and `GenerationConfig.from_pretrained`.
            revision (`Optional[str]`, defaults to `None`):
                Forwarded to `snapshot_download` and `GenerationConfig.from_pretrained`.
            force_download (`bool`, defaults to `False`):
                Forwarded to `snapshot_download` and `GenerationConfig.from_pretrained`.
            cache_dir (`Optional[str]`, defaults to `None`):
                Forwarded to `snapshot_download` and `GenerationConfig.from_pretrained`.
            encoder_file_name (`Optional[str]`, defaults to `NEURON_FILE_NAME`):
                File name of the encoder artifact inside the `ENCODER_NAME` folder.
            decoder_file_name (`Optional[str]`, defaults to `NEURON_FILE_NAME`):
                File name of the decoder artifact inside the `DECODER_NAME` folder.
            subfolder (`str`, defaults to `""`):
                Only forwarded to `maybe_load_preprocessors`.
            local_files_only (`bool`, defaults to `False`):
                Forwarded to `snapshot_download` and `GenerationConfig.from_pretrained`.
            model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`):
                Save directory recorded on the model; defaults to the resolved model directory.
            kwargs:
                Accepted but not used.

        Returns:
            An instance of `cls`.
        """
        model_id = str(model_id)

        if not os.path.isdir(model_id):
            # Downloads all repo's files matching the allowed patterns
            model_id = snapshot_download(
                model_id,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                force_download=force_download,
                ignore_patterns=["*.msgpack", "*.safetensors", "*.bin"],  # only download *.neuron artifacts
            )

        preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)

        new_model_save_dir = Path(model_id)

        # Keyed with the same constants that index this mapping below and `self.configs` /
        # `self.neuron_configs` in `__init__`, instead of separate string literals.
        model_and_config_save_paths = {
            ENCODER_NAME: (
                new_model_save_dir / ENCODER_NAME / encoder_file_name,
                new_model_save_dir / ENCODER_NAME / cls.config_name,
            ),
            DECODER_NAME: (
                new_model_save_dir / DECODER_NAME / decoder_file_name,
                new_model_save_dir / DECODER_NAME / cls.config_name,
            ),
        }

        # Re-build pretrained configs and neuron configs
        configs, neuron_configs = {}, {}
        for name, (_, config_path) in model_and_config_save_paths.items():
            if config_path.is_file():
                model_config = AutoConfig.from_pretrained(config_path)
                configs[name] = model_config
                neuron_configs[name] = cls._neuron_config_init(model_config)

        encoder, decoder = cls.load_model(
            encoder_path=model_and_config_save_paths[ENCODER_NAME][0],
            decoder_path=model_and_config_save_paths[DECODER_NAME][0],
            tensor_parallel_size=configs[DECODER_NAME].neuron.get("tensor_parallel_size", 1),
        )

        if model_save_dir is None:
            model_save_dir = new_model_save_dir

        generation_config = None
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )
        except OSError:
            # A missing generation config is not an error: the constructor derives one from the decoder config.
            logger.info("Generation config file not found, using a generation config created from the model config.")

        return cls(
            encoder=encoder,
            decoder=decoder,
            config=config,
            model_save_dir=model_save_dir,
            encoder_file_name=encoder_file_name,
            decoder_file_name=decoder_file_name,
            preprocessors=preprocessors,
            neuron_configs=neuron_configs,
            configs=configs,
            generation_config=generation_config,
        )

    @classmethod
    def _from_transformers(cls, *args, **kwargs):
        """Forwards every argument to `_export` and returns its result."""
        # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release.
        exported_model = cls._export(*args, **kwargs)
        return exported_model

    @classmethod
    def _export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        token: Optional[Union[bool, str]] = None,
        revision: str = "main",
        force_download: bool = True,
        cache_dir: Optional[str] = None,
        compiler_workdir: Optional[str] = None,
        tensor_parallel_size: Optional[int] = 1,
        inline_weights_to_neff: bool = True,
        optlevel: str = "2",
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        task: Optional[str] = None,
        auto_cast: Optional[str] = "matmul",
        auto_cast_type: Optional[str] = "bf16",
        disable_fast_relayout: Optional[bool] = False,
        disable_fallback: bool = False,
        dynamic_batch_size: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs_shapes,
    ) -> "NeuronModelForConditionalGeneration":
        """
        Exports `model_id` with `main_export` into a temporary directory and loads the result.

        Args:
            model_id (`str`):
                Passed to `main_export` as `model_name_or_path`.
            config (`PretrainedConfig`):
                Configuration handed to `_from_pretrained` for the loaded model.
            task (`Optional[str]`, defaults to `None`):
                Export task; inferred from `cls.auto_model_class` when omitted.
            auto_cast (`Optional[str]`, defaults to `"matmul"`) and auto_cast_type (`Optional[str]`, defaults to `"bf16"`):
                Passed to the exporter through `compiler_kwargs`. `auto_cast_type` is replaced by `None` when
                `auto_cast` is `None`.
            disable_fast_relayout (`Optional[bool]`, defaults to `False`) and disable_fallback (`bool`, defaults to `False`):
                Passed to the exporter through `compiler_kwargs`.
            dynamic_batch_size (`bool`, defaults to `False`):
                Not supported for sequence-to-sequence models: `True` triggers a warning and is replaced by
                `False` before the export.
            kwargs_shapes:
                Extra keyword arguments forwarded unchanged to `main_export`.

        All remaining parameters are forwarded unchanged to `main_export` under the same name.

        Returns:
            `NeuronModelForConditionalGeneration`: The model loaded from the temporary export directory. The
            `TemporaryDirectory` object is passed along as `model_save_dir`.
        """
        if dynamic_batch_size is True:
            logger.warning(
                "Sequence-to-sequence models don't support dynamic batch size yet, `dynamic_batch_size` will be set to False."
            )
            # Do what the warning announces; previously `True` was still forwarded to the exporter.
            dynamic_batch_size = False

        if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)

        # Get compilation arguments
        auto_cast_type = None if auto_cast is None else auto_cast_type
        compiler_kwargs = {
            "auto_cast": auto_cast,
            "auto_cast_type": auto_cast_type,
            "disable_fast_relayout": disable_fast_relayout,
            "disable_fallback": disable_fallback,
        }

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        main_export(
            model_name_or_path=model_id,
            output=save_dir_path,
            compiler_kwargs=compiler_kwargs,
            tensor_parallel_size=tensor_parallel_size,
            task=task,
            dynamic_batch_size=dynamic_batch_size,
            cache_dir=cache_dir,
            compiler_workdir=compiler_workdir,
            inline_weights_to_neff=inline_weights_to_neff,
            optlevel=optlevel,
            trust_remote_code=trust_remote_code,
            subfolder=subfolder,
            revision=revision,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            do_validation=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs_shapes,
        )

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
        )

    def _save_config(self, save_directory):
        """
        Writes three configurations: the encoder's and the decoder's into their own sub-folders of
        `save_directory`, then the combined one at the root of `save_directory`.
        """
        root_dir = Path(save_directory)
        encoder_config = self.configs[ENCODER_NAME]
        decoder_config = self.configs[DECODER_NAME]

        encoder_config.save_pretrained(root_dir / ENCODER_NAME)
        decoder_config.save_pretrained(root_dir / DECODER_NAME)

        merged_config = self._combine_encoder_decoder_config(
            encoder_config=encoder_config,
            decoder_config=decoder_config,
        )
        merged_config.save_pretrained(root_dir)

    def _combine_encoder_decoder_config(self, encoder_config: "PretrainedConfig", decoder_config: "PretrainedConfig"):
        encoder_neuron_config = encoder_config.neuron
        decoder_neuron_config = decoder_config.neuron
        combined_config = copy.deepcopy(encoder_config)

        encoder_neuron_config["encoder_input_names"] = encoder_neuron_config.pop("input_names")
        encoder_neuron_config["encoder_output_names"] = encoder_neuron_config.pop("output_names")
        decoder_neuron_config["decoder_input_names"] = decoder_neuron_config.pop("input_names")
        decoder_neuron_config["decoder_output_names"] = decoder_neuron_config.pop("output_names")

        encoder_neuron_config.update(decoder_neuron_config)
        encoder_neuron_config.pop("model_type")
        combined_config.__setattr__("neuron", encoder_neuron_config)

        return combined_config


@add_start_docstrings(
    """
    Neuron Sequence-to-sequence model with a language modeling head for text2text-generation tasks.
    """,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
)
class NeuronModelForSeq2SeqLM(NeuronModelForConditionalGeneration, NeuronGenerationMixin):
    # Auto class handed to `TasksManager.infer_task_from_model` when `_export` is called without a task.
    auto_model_class = AutoModelForSeq2SeqLM
    # Name of the main input; presumably read by the generation utilities -- TODO confirm.
    main_input_name = "input_ids"

    @add_start_docstrings_to_model_forward(
        NEURON_SEQ2SEQ_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + NEURON_TRANSLATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForSeq2SeqLM",
            checkpoint="google-t5/t5-small",
            save_dir="t5_small_neuronx",
        )
        + NEURON_TRANSLATION_TP_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="NeuronModelForSeq2SeqLM",
            checkpoint="google/flan-t5-xl",
            save_dir="flan_t5_xl_neuronx_tp8",
        )
    )
    def forward(
        self,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        beam_scores: Optional[torch.FloatTensor] = None,
        return_dict: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Union[Tuple[torch.FloatTensor], ModelOutput]:
        # Documented with comments rather than a docstring, since the decorator above supplies this method's
        # docstring.
        #
        # Runs one step of the traced decoder. The first three decoder outputs are returned as
        # `next_token_scores`, `next_tokens` and `next_indices`. Attentions and hidden states are returned only
        # when they are both requested here and enabled in the decoder's neuron settings; otherwise they are
        # `None`. Although `attention_mask` and `encoder_outputs` default to `None`, both are required.
        hidden_states = encoder_outputs["last_hidden_state"]

        if not hasattr(self, "beam_idx"):
            # Inferring the number of beams from the attention mask
            num_beams = attention_mask.shape[0]
            self.beam_idx = torch.arange(0, num_beams, dtype=torch.int64)

        outputs = self.decoder(
            decoder_input_ids, decoder_attention_mask, hidden_states, attention_mask, self.beam_idx, beam_scores
        )

        # Fetch optional outputs
        cross_attentions = None
        decoder_attentions = None
        decoder_hidden_states = None

        # Skip pkv which can't be copied from memory to buffer
        if output_attentions or output_hidden_states:
            decoder_neuron_config = self.configs["decoder"].neuron
            attentions_in_outputs = bool(decoder_neuron_config.get("output_attentions"))
            want_attentions = output_attentions and attentions_in_outputs
            want_hidden_states = output_hidden_states and bool(decoder_neuron_config.get("output_hidden_states"))

            if want_attentions or want_hidden_states:
                num_layers = self.config.num_decoder_layers
                # Optional outputs are read from the tail of `outputs`: cross attentions last, decoder attentions
                # before them, hidden states before those. `tail_end` is the exclusive end of the part of the
                # tail that has not been consumed yet. Non-negative bounds are used on purpose: a stop of `-0`
                # is `0`, which silently produced an empty slice whenever nothing had been consumed yet.
                tail_end = len(outputs)

                if attentions_in_outputs:
                    # NOTE(review): assumes a decoder traced with `output_attentions` always emits the
                    # attentions, so they are stepped over even when not requested -- confirm against the
                    # exporter.
                    if self.config.is_encoder_decoder:
                        if want_attentions:
                            cross_attentions = outputs[tail_end - num_layers : tail_end]
                        tail_end -= num_layers
                    if want_attentions:
                        decoder_attentions = outputs[tail_end - num_layers : tail_end]
                    tail_end -= num_layers

                if want_hidden_states:
                    # One more entry than the number of layers is read for the hidden states.
                    decoder_hidden_states = outputs[tail_end - (num_layers + 1) : tail_end]

        decoder_outputs = ModelOutput(
            next_token_scores=outputs[0],
            next_tokens=outputs[1],
            next_indices=outputs[2],
            cross_attentions=cross_attentions,
            decoder_attentions=decoder_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )

        if return_dict:
            return decoder_outputs
        else:
            return decoder_outputs.to_tuple()

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        num_return_sequences: int = 1,
        **kwargs,
    ):
        """
        Runs the traced encoder once, hands its outputs to the decoder as its cache, then delegates to the
        parent `generate`.

        The sequence length, number of beams and batch size come from the encoder's neuron configuration.
        `num_beams` and `use_cache` cannot be overridden by the caller.

        Only these entries of `kwargs` are honoured: `max_length`, `max_new_tokens`, `output_attentions`,
        `output_hidden_states`, `output_scores`, `return_dict_in_generate` and `do_sample`. Any other entry is
        not forwarded.

        Returns:
            Whatever the parent `generate` returns.
        """
        encoder_neuron_config = self.neuron_configs[ENCODER_NAME]
        max_length = encoder_neuron_config.sequence_length
        num_beams = encoder_neuron_config.num_beams
        batch_size = encoder_neuron_config.batch_size

        inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            inputs["attention_mask"] = attention_mask
        inputs = self._pad_to_compiled_shape(inputs)

        # The outputs of the traced "encoder" are used as the decoder's initial cache (see below).
        encoder_cache = self.encoder(**inputs)

        # Decoder mask of compiled length: zeros everywhere except the last position.
        masked_positions = torch.zeros((batch_size, max_length - 1), dtype=torch.int64)
        active_position = torch.ones((batch_size, 1), dtype=torch.int64)
        decoder_attention_mask = torch.cat([masked_positions, active_position], dim=1)

        if self.tensor_parallel_size != 1:
            # Here we iterate sharded encoders and decoders since the encoder on each rank will return cache as device tensors,
            # we want to assign them to the cache of the sharded decoder on the same rank to avoid the copy. The KV cache always
            # use pre-allocated memory, no host-device communication overhead.
            for decoder_shard, encoder_shard in zip(self.decoder.model.models, self.encoder.model.models):
                decoder_shard.load_state_dict(encoder_shard.state_dict(), strict=False)
        else:
            # copy the new cache state to the decoder
            for decoder_state, cache_tensor in zip(self.decoder.model.parameters(), encoder_cache):
                decoder_state.copy_(cache_tensor)

        # Options the caller may override; a missing or falsy `max_length` falls back to the compiled length.
        effective_max_length = kwargs.pop("max_length", None) or max_length
        max_new_tokens = kwargs.pop("max_new_tokens", None)
        output_attentions = kwargs.pop("output_attentions", False)
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        output_scores = kwargs.pop("output_scores", False)
        return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
        do_sample = kwargs.pop("do_sample", False)

        # Placeholder encoder outputs, so the transformers code does not invoke the encoder again.
        placeholder_encoder_outputs = {"last_hidden_state": torch.ones((batch_size, max_length, 1))}

        return super().generate(
            **inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            assistant_model=assistant_model,
            num_return_sequences=num_return_sequences,
            max_length=effective_max_length,
            max_new_tokens=max_new_tokens,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            num_beams=num_beams,
            do_sample=do_sample,
            use_cache=True,  # the past key values are always cached
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=placeholder_encoder_outputs,
            is_traced_inference=True,
        )

    def _reorder_cache(self, beam_idx):
        """
        The cache was reordered during the tracing of the decoder so we can skip it here. This is needed for beam search and not greedy sampling.
        """
        self.beam_idx = beam_idx

    def get_encoder(self) -> "NeuronEncoder":
        return self.encoder

    def _update_model_kwargs_for_xla_generation(
        self,
        model_kwargs: Dict[str, Any],
        batch_size: int,
        is_encoder_decoder: bool = False,
        # Leave following kwargs for compatibility, will not have any effect.
        outputs: "ModelOutput" = None,
        standardize_cache_format: bool = False,
        max_length: Optional[int] = None,
        seq_length: Optional[int] = None,
        use_cache: bool = True,
    ) -> Dict[str, Any]:
        """
        Merges the result of `_update_attention` into `model_kwargs`.

        `model_kwargs` is updated in place and also returned. `outputs`, `standardize_cache_format`,
        `max_length`, `seq_length` and `use_cache` are accepted but ignored.
        """
        attention_updates = self._update_attention(model_kwargs, batch_size, is_encoder_decoder)
        model_kwargs.update(attention_updates)
        return model_kwargs

    # Override to cut the input_ids to just last token
    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Builds the keyword arguments for one decoding step.

        Only the last column of `input_ids` is kept as `decoder_input_ids`, because the past is cached. The other
        named arguments are passed through unchanged and extra `kwargs` are ignored.
        """
        last_token_ids = input_ids[:, -1:]
        return dict(
            decoder_input_ids=last_token_ids,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

    def _validate_static_shape(self, input_shapes: List[int], target_shapes: List[int]) -> bool:
        """
        Checks if a input needs to be padded.
        """
        return input_shapes == target_shapes

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True
