# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Whisper model on Neuron devices."""

import logging
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import GenerationConfig, WhisperForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.utils import ModelOutput

from ....exporters.neuron import (
    NeuronDefaultConfig,
)
from ...modeling_seq2seq import NeuronModelForConditionalGeneration, _NeuronSeq2SeqModelPart
from ...modeling_traced import NeuronTracedModel
from ...utils import (
    NEURON_FILE_NAME,
    is_neuronx_available,
    is_neuronx_distributed_available,
)
from ...utils.doc import (
    NEURON_AUDIO_SEQ2SEQ_INPUTS_DOCSTRING,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig

# NOTE(review): both availability guards below have empty bodies, so they have no effect at import
# time. They look like leftovers from removed conditional imports — confirm and drop if so.
if is_neuronx_available():
    pass

if is_neuronx_distributed_available():
    pass

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class DummyLayer:
    """
    Attribute-only stand-in for a layer.

    Every keyword argument passed to the constructor becomes an attribute of the
    instance. Calling the instance returns its argument unchanged.
    """

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __call__(self, x):
        # Identity: the input is handed back untouched.
        return x


class NeuronWhisperEncoder(_NeuronSeq2SeqModelPart):
    """
    Encoder and the 1st forward of decoder+language head.

    Wraps a traced module called as `model(input_features, decoder_input_ids)`. When no decoder ids
    are supplied, one `decoder_start_token_id` per batch row is used. In that case only the second
    output of the traced module is returned, wrapped in a `BaseModelOutput`. Otherwise the raw traced
    outputs are returned as-is.
    """

    main_input_name = "input_features"

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronTracedModel,
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        super().__init__(model, parent_model, config, neuron_config, "encoder")
        # Expose `conv1.stride` / `conv2.stride` like the PyTorch Whisper encoder does, using the
        # config's `stride` when present and [1, 2] otherwise. Presumably read by the transformers
        # Whisper generation code to derive the input stride — TODO confirm.
        conv_strides = getattr(self.config, "stride", [1, 2])
        self.conv1 = DummyLayer(stride=[conv_strides[0]])
        self.conv2 = DummyLayer(stride=[conv_strides[1]])

    def forward(
        self,
        input_features: torch.FloatTensor,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # Caller-supplied decoder ids: hand back the traced outputs untouched.
        if decoder_input_ids is not None:
            return self.model(input_features, decoder_input_ids)

        # No decoder ids: presumably the generation pre-pass that only needs encoder states. Seed the
        # decoder with a single start token per row (static compiled batch size) and keep only the
        # encoder hidden states.
        start_ids = torch.full(
            (self.neuron_config.batch_size, 1), self.config.decoder_start_token_id, dtype=torch.long
        )
        traced_outputs = self.model(input_features, start_ids)
        return BaseModelOutput(last_hidden_state=traced_outputs[1])


class NeuronWhisperDecoder(_NeuronSeq2SeqModelPart):
    """
    Decoder with output embedding of the whisper model for Neuron inference.

    `forward` runs the traced module as `model(decoder_input_ids, encoder_hidden_states)`. It returns
    the pair `(traced_outputs, encoder_hidden_states)`, echoing the encoder states back unchanged.
    """

    def __init__(
        self,
        model: torch.jit._script.ScriptModule,
        parent_model: NeuronTracedModel,
        config: Optional["PretrainedConfig"] = None,
        neuron_config: Optional[Dict[str, str]] = None,
    ):
        super().__init__(model, parent_model, config, neuron_config, "decoder")

    def forward(
        self,
        decoder_input_ids: Optional[torch.LongTensor],
        encoder_hidden_states: Optional[torch.FloatTensor],
        **kwargs,
    ):
        traced_outputs = self.model(decoder_input_ids, encoder_hidden_states)
        # Extra keyword arguments are accepted for API compatibility and ignored.
        return traced_outputs, encoder_hidden_states


class NeuronWhisperModel:
    """
    Plain container holding the Neuron Whisper encoder and decoder parts.

    Exposes them as `.encoder` and `.decoder`, mirroring the attribute layout of the
    transformers Whisper model. Presumably this keeps code that reaches into
    `model.encoder` working — TODO confirm.
    """

    def __init__(self, encoder: NeuronWhisperEncoder, decoder: NeuronWhisperDecoder):
        self.encoder, self.decoder = encoder, decoder


@add_start_docstrings(
    """
    Whisper Neuron model with a language modeling head that can be used for automatic speech recognition.
    """,
    NEURON_SEQ2SEQ_MODEL_START_DOCSTRING,
)
class NeuronWhisperForConditionalGeneration(NeuronModelForConditionalGeneration, WhisperForConditionalGeneration):
    # The public class docstring is assembled by `add_start_docstrings` above, so documentation here
    # is kept in comments to avoid altering the rendered docs.
    #
    # Inference uses two traced parts:
    #   * `encoder` (`NeuronWhisperEncoder`): encoder plus the first decoder/LM-head forward;
    #   * `decoder` (`NeuronWhisperDecoder`): one decoder/LM-head pass over ids padded to the compiled
    #     decoder `sequence_length`.
    # Caching is disabled during generation (see `_update_model_kwargs_for_generation`).
    auto_model_class = WhisperForConditionalGeneration
    main_input_name = "input_features"
    encoder_class = NeuronWhisperEncoder
    decoder_class = NeuronWhisperDecoder

    def __init__(
        self,
        encoder: torch.jit._script.ScriptModule,
        decoder: torch.jit._script.ScriptModule,
        config: "PretrainedConfig",
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        encoder_file_name: Optional[str] = NEURON_FILE_NAME,
        decoder_file_name: Optional[str] = NEURON_FILE_NAME,
        preprocessors: Optional[List] = None,
        neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None,
        configs: Optional[Dict[str, "PretrainedConfig"]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        super().__init__(
            encoder,
            decoder,
            config,
            model_save_dir,
            encoder_file_name,
            decoder_file_name,
            preprocessors,
            neuron_configs,
            configs,
            generation_config,
            **kwargs,
        )
        # Expose `model.encoder` / `model.decoder`, mirroring the layout of the transformers model.
        self.model = NeuronWhisperModel(self.encoder, self.decoder)

    @property
    def device(self):
        # Always reports CPU, regardless of where the traced modules execute.
        return torch.device("cpu")

    def get_encoder(self) -> "NeuronWhisperEncoder":
        """Return the Neuron encoder part of the model."""
        return self.encoder

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        """
        Force `use_cache=False`, then delegate to the parent implementation.

        Whisper with a KV cache is not yet supported on Neuron, so every generation step
        re-runs the decoder on the full (padded) sequence.
        """
        model_kwargs["use_cache"] = False

        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )

        return model_kwargs

    @add_start_docstrings_to_model_forward(NEURON_AUDIO_SEQ2SEQ_INPUTS_DOCSTRING)
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        # Two execution paths:
        #   * without `encoder_outputs`: run the traced encoder part on `input_features` and
        #     `decoder_input_ids`;
        #   * with `encoder_outputs`: right-pad `decoder_input_ids` to the compiled decoder sequence
        #     length, run the traced decoder, then drop the logits of the padded positions.
        # Returns a `Seq2SeqLMOutput` holding `logits` and `encoder_last_hidden_state`.
        # Raises `ValueError` if `decoder_input_ids` is missing, or if it is longer than the compiled
        # decoder sequence length.
        if decoder_input_ids is None:
            # Without ids the encoder part returns a single-field `BaseModelOutput`, and the decoder path
            # needs the ids' length, so neither path can produce logits.
            raise ValueError(f"`decoder_input_ids` must be provided to {self.__class__.__name__}.forward.")

        if encoder_outputs is None:
            lm_logits, encoder_last_hidden_state = self.encoder(
                input_features=input_features, decoder_input_ids=decoder_input_ids
            )
        else:
            # pad `decoder_input_ids` to the sequence length of the compilation
            decoder_input_ids_length = decoder_input_ids.shape[1]
            compiled_length = self.neuron_configs["decoder"].sequence_length
            pad_size = compiled_length - decoder_input_ids_length
            if pad_size < 0:
                # A negative pad would silently crop the ids to `compiled_length` tokens. The returned
                # logits would then no longer line up with `decoder_input_ids`.
                raise ValueError(
                    f"`decoder_input_ids` has length {decoder_input_ids_length}, which exceeds the "
                    f"sequence length the decoder was compiled with ({compiled_length})."
                )
            decoder_input_ids = torch.nn.functional.pad(
                decoder_input_ids, (0, pad_size), "constant", self.preprocessors[0].pad_token_id
            )

            lm_logits, encoder_last_hidden_state = self.decoder(
                decoder_input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_outputs[0],
            )
            # unpad: keep only the logits of the real (non-padding) positions
            lm_logits = lm_logits[:, :decoder_input_ids_length, :]

        return Seq2SeqLMOutput(
            logits=lm_logits,
            encoder_last_hidden_state=encoder_last_hidden_state,
        )
