# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model specific Neuron configurations."""

import copy
from functools import partial
from typing import TYPE_CHECKING, Dict, List

import torch

from optimum.exporters.tasks import TasksManager
from optimum.utils import (
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
    DummyTextInputGenerator,
    DummyTimestepInputGenerator,
    DummyVisionInputGenerator,
    NormalizedConfig,
    NormalizedConfigManager,
    NormalizedSeq2SeqConfig,
    NormalizedTextAndVisionConfig,
    NormalizedTextConfig,
    NormalizedVisionConfig,
    is_diffusers_available,
)

from ...neuron.utils import (
    ASTDummyAudioInputGenerator,
    DummyBeamValuesGenerator,
    DummyControNetInputGenerator,
    DummyIPAdapterInputGenerator,
    DummyMaskedPosGenerator,
    WhisperDummyTextInputGenerator,
    is_neuronx_distributed_available,
    saved_model_in_temporary_directory,
)
from .config import (
    AudioNeuronConfig,
    TextAndVisionNeuronConfig,
    TextEncoderNeuronConfig,
    TextSeq2SeqNeuronConfig,
    VisionNeuronConfig,
)
from .model_wrappers import (
    CLIPVisionWithProjectionNeuronWrapper,
    ControlNetNeuronWrapper,
    NoCacheModelWrapper,
    PixartTransformerNeuronWrapper,
    SentenceTransformersCLIPNeuronWrapper,
    SentenceTransformersTransformerNeuronWrapper,
    T5DecoderWrapper,
    T5EncoderForSeq2SeqLMWrapper,
    T5EncoderWrapper,
    UnetNeuronWrapper,
    WhisperDecoderWrapper,
    WhisperEncoderWrapper,
)


if is_neuronx_distributed_available():
    import neuronx_distributed

if TYPE_CHECKING:
    if is_diffusers_available():
        from diffusers.models.vae import Decoder as VaeDecoder


COMMON_TEXT_TASKS = [
    "feature-extraction",
    "fill-mask",
    "multiple-choice",
    "question-answering",
    "text-classification",
    "token-classification",
]
register_in_tasks_manager = TasksManager.create_register("neuron")
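# `register_in_tasks_manager` is a decorator factory: decorating a config class below registers it in the
# "neuron" backend of `TasksManager` for the given model type and tasks. As an illustration (exact lookup
# arguments may vary), a registered config can then be retrieved with something like
# `TasksManager.get_exporter_config_constructor(exporter="neuron", model_type="bert", task="fill-mask")`.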


@register_in_tasks_manager("bert", *COMMON_TEXT_TASKS)
class BertNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask", "token_type_ids"]


@register_in_tasks_manager("albert", *COMMON_TEXT_TASKS)
class AlbertNeuronConfig(BertNeuronConfig):
    pass


@register_in_tasks_manager("convbert", *COMMON_TEXT_TASKS)
class ConvBertNeuronConfig(BertNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-1  # TODO: investigate why the accuracy is more off than for other architectures

    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("electra", *COMMON_TEXT_TASKS)
class ElectraNeuronConfig(BertNeuronConfig):
    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("esm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"])
class EsmNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS)
class FlaubertNeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager("mobilebert", *COMMON_TEXT_TASKS)
class MobileBertNeuronConfig(BertNeuronConfig):
    pass


@register_in_tasks_manager(
    "modernbert", *["feature-extraction", "fill-mask", "text-classification", "token-classification"]
)
class ModernBertNeuronConfig(BertNeuronConfig):
    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("phi", *["feature-extraction", "text-classification", "token-classification"])
class PhiNeuronConfig(ElectraNeuronConfig):
    CUSTOM_MODEL_WRAPPER = NoCacheModelWrapper

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager("xlm", *COMMON_TEXT_TASKS)
class XLMNeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager("distilbert", *COMMON_TEXT_TASKS)
class DistilBertNeuronConfig(BertNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        if self.task == "feature-extraction":
            return ["last_hidden_state"]
        return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("camembert", *COMMON_TEXT_TASKS)
class CamembertNeuronConfig(BertNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("mpnet", *COMMON_TEXT_TASKS)
class MPNetNeuronConfig(CamembertNeuronConfig):
    pass


@register_in_tasks_manager("roberta", *COMMON_TEXT_TASKS)
class RobertaNeuronConfig(CamembertNeuronConfig):
    pass


@register_in_tasks_manager("xlm-roberta", *COMMON_TEXT_TASKS)
class XLMRobertaNeuronConfig(CamembertNeuronConfig):
    pass


# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Fails only on INF1: 'XSoftmax'
@register_in_tasks_manager("deberta", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaNeuronConfig(ElectraNeuronConfig):
    @property
    def inputs(self) -> List[str]:
        common_inputs = super().inputs
        if self._config.type_vocab_size == 0:
            # We remove token type ids.
            common_inputs.pop(-1)
        return common_inputs


# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Fails only on INF1: 'XSoftmax'
@register_in_tasks_manager("deberta-v2", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaV2NeuronConfig(ElectraNeuronConfig):
    pass


@register_in_tasks_manager(
    "transformer", *["feature-extraction", "sentence-similarity"], library_name="sentence_transformers"
)
class SentenceTransformersTransformerNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    CUSTOM_MODEL_WRAPPER = SentenceTransformersTransformerNeuronWrapper
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        return ["token_embeddings", "sentence_embedding"]


class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
    TEXT_CONFIG = "text_config"
    VISION_CONFIG = "vision_config"


@register_in_tasks_manager("clip-vision-with-projection", *["feature-extraction"], library_name="diffusers")
class CLIPVisionWithProjectionNeuronConfig(VisionNeuronConfig):
    MODEL_TYPE = "clip-vision-with-projection"
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    CUSTOM_MODEL_WRAPPER = CLIPVisionWithProjectionNeuronWrapper

    @property
    def inputs(self) -> List[str]:
        return ["pixel_values"]

    @property
    def outputs(self) -> List[str]:
        common_outputs = ["image_embeds", "last_hidden_state"]
        if self.output_hidden_states:
            common_outputs.append("hidden_states")
        return common_outputs


@register_in_tasks_manager("clip", *["feature-extraction", "zero-shot-image-classification", "image-classification"])
class CLIPNeuronConfig(TextAndVisionNeuronConfig):
    NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
    INPUT_ARGS = ("text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height")

    @property
    def inputs(self) -> List[str]:
        if self.task == "image-classification":
            return ["pixel_values"]
        else:
            return ["input_ids", "pixel_values", "attention_mask"]

    @property
    def outputs(self) -> List[str]:
        if self.task == "image-classification":
            return ["logits"]
        else:
            return [
                "logits_per_image",
                "logits_per_text",
                "text_embeds",
                "image_embeds",
                "text_model_output",
                "vision_model_output",
            ]

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        for name, axis_dim in self._axes.items():
            self._axes[name] = kwargs.pop(name, axis_dim)

        self._validate_mandatory_axes()

        other_axes = copy.deepcopy(self._axes)
        text_batch_size = other_axes.pop("text_batch_size")
        images_batch_size = other_axes.pop("image_batch_size")

        return [
            DummyTextInputGenerator(self.task, self._normalized_config, batch_size=text_batch_size, **other_axes),
            DummyVisionInputGenerator(self.task, self._normalized_config, batch_size=images_batch_size, **other_axes),
        ]


@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers")
class CLIPTextWithProjectionNeuronConfig(TextEncoderNeuronConfig):
    MODEL_TYPE = "clip-text-with-projection"
    ATOL_FOR_VALIDATION = 1e-3

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        vocab_size="vocab_size",
        sequence_length="max_position_embeddings",
        num_layers="num_hidden_layers",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["input_ids"]

    @property
    def outputs(self) -> List[str]:
        common_outputs = ["text_embeds", "last_hidden_state"]

        if self._normalized_config.output_hidden_states:
            common_outputs.append("hidden_states")

        return common_outputs


@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="diffusers")
class CLIPTextNeuronConfig(CLIPTextWithProjectionNeuronConfig):
    MODEL_TYPE = "clip-text-model"

    @property
    def outputs(self) -> List[str]:
        common_outputs = ["last_hidden_state", "pooler_output"]

        if self._normalized_config.output_hidden_states:
            common_outputs.append("hidden_states")

        return common_outputs


# TODO: We should decouple the CLIP text and vision encoders; this would require a fix on Optimum main.
# As a workaround, users can pass dummy text inputs when encoding an image, and vice versa.
@register_in_tasks_manager(
    "clip", *["feature-extraction", "sentence-similarity"], library_name="sentence_transformers"
)
class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig):
    CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper
    ATOL_FOR_VALIDATION = 1e-3

    @property
    def outputs(self) -> List[str]:
        return ["text_embeds", "image_embeds"]


@register_in_tasks_manager("vit", *["feature-extraction", "image-classification"])
class ViTNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyMaskedPosGenerator)
    INPUT_ARGS = ("batch_size",)  # `num_channels` and `image_size` are inferred from the config

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["pixel_values"]
        if self.task == "masked-im":
            common_inputs.append("bool_masked_pos")
        return common_inputs


@register_in_tasks_manager("beit", *["feature-extraction", "image-classification"])
class BeitNeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("convnext", *["feature-extraction", "image-classification"])
class ConvNextNeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("convnextv2", *["feature-extraction", "image-classification"])
class ConvNextV2NeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("cvt", *["feature-extraction", "image-classification"])
class CvTNeuronConfig(ViTNeuronConfig):
    MODEL_TYPE = "cvt"

    @property
    def outputs(self) -> List[str]:
        common_outputs = super().outputs
        if self.task == "feature-extraction":
            return ["last_hidden_state", "cls_token_value"]
        else:
            return common_outputs


@register_in_tasks_manager("deit", *["feature-extraction", "image-classification"])
class DeiTNeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("donut-swin", *["feature-extraction"])
class DonutSwinNeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("dpt", *["feature-extraction"])
class DptNeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("levit", *["feature-extraction", "image-classification"])
class LevitNeuronConfig(ViTNeuronConfig):
    MODEL_TYPE = "levit"
    pass


@register_in_tasks_manager(
    "mobilenet-v2", *["feature-extraction", "image-classification", "semantic-segmentation", "image-segmentation"]
)
class MobileNetV2NeuronConfig(ViTNeuronConfig):
    MODEL_TYPE = "mobilenet-v2"
    pass


@register_in_tasks_manager(
    "mobilevit", *["feature-extraction", "image-classification", "semantic-segmentation", "image-segmentation"]
)
class MobileViTNeuronConfig(ViTNeuronConfig):
    MODEL_TYPE = "mobilevit"
    pass


@register_in_tasks_manager("swin", *["feature-extraction", "image-classification"])
class SwinNeuronConfig(ViTNeuronConfig):
    pass


@register_in_tasks_manager("yolos", *["feature-extraction", "object-detection"])
class YolosTNeuronConfig(ViTNeuronConfig):
    @property
    def outputs(self) -> List[str]:
        common_outputs = super().outputs
        if self.task == "object-detection":
            common_outputs.append("last_hidden_state")
        return common_outputs


@register_in_tasks_manager(
    "wav2vec2",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-frame-classification",
        "audio-xvector",
    ],
)
class Wav2Vec2NeuronConfig(AudioNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfig
    MODEL_TYPE = "wav2vec2"

    @property
    def inputs(self) -> List[str]:
        return ["input_values"]

    @property
    def outputs(self) -> List[str]:
        common_outputs = super().outputs
        if self.task == "feature-extraction":
            common_outputs = ["last_hidden_state", "extract_features"]
        if self.task == "audio-xvector":
            common_outputs.append("embeddings")
        return common_outputs


@register_in_tasks_manager(
    "audio-spectrogram-transformer",
    *[
        "feature-extraction",
        "audio-classification",
    ],
)
class ASTNeuronConfig(AudioNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_mel_bins="num_mel_bins", max_length="max_length", allow_new=True
    )
    DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,)

    @property
    def inputs(self) -> List[str]:
        return ["input_values"]


@register_in_tasks_manager(
    "hubert",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
    ],
)
class HubertNeuronConfig(Wav2Vec2NeuronConfig):
    MODEL_TYPE = "hubert"

    @property
    def outputs(self) -> List[str]:
        common_outputs = super().outputs
        if self.task == "feature-extraction":
            common_outputs = ["last_hidden_state"]
        return common_outputs


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "sew",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#     ],
# )
# class SEWNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "sew-d",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#     ],
# )
# class SEWDNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


@register_in_tasks_manager(
    "unispeech",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
    ],
)
class UniSpeechNeuronConfig(Wav2Vec2NeuronConfig):
    MODEL_TYPE = "unispeech"
    pass


@register_in_tasks_manager(
    "unispeech-sat",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-frame-classification",
        "audio-xvector",
    ],
)
class UniSpeechSATNeuronConfig(Wav2Vec2NeuronConfig):
    MODEL_TYPE = "unispeech-sat"
    pass


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "wav2vec2-bert",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#         "audio-frame-classification",
#         "audio-xvector",
#     ],
# )
# class Wav2Vec2BertNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "wav2vec2-conformer",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#         "audio-frame-classification",
#         "audio-xvector",
#     ],
# )
# class Wav2Vec2ConformerNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


@register_in_tasks_manager(
    "wavlm",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-frame-classification",
        "audio-xvector",
    ],
)
class WavLMNeuronConfig(Wav2Vec2NeuronConfig):
    MODEL_TYPE = "wavlm"
    pass


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height", "vae_scale_factor")
    MODEL_TYPE = "unet"
    CUSTOM_MODEL_WRAPPER = UnetNeuronWrapper
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        height="height",
        width="width",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
        DummyControNetInputGenerator,
        DummyIPAdapterInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["sample", "timestep", "encoder_hidden_states"]

        # TODO: add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs.append("text_embeds")
            common_inputs.append("time_ids")

        if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
            common_inputs.append("timestep_cond")

        if self.with_controlnet:
            # outputs of controlnet
            common_inputs += ["down_block_additional_residuals", "mid_block_additional_residual"]

        if self.with_ip_adapter:
            # add output of image encoder
            if self.image_encoder_output_hidden_states:
                common_inputs += ["image_enc_hidden_states"]
            else:
                common_inputs += ["image_embeds"]

        return common_inputs

    @property
    def outputs(self) -> List[str]:
        return ["sample"]

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        dummy_inputs = super().generate_dummy_inputs(**kwargs)
        dummy_inputs["timestep"] = dummy_inputs["timestep"].float()
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        # Break `down_block_additional_residuals` down into one input per residual tensor
        num_down_block_outputs = len(self._normalized_config.down_block_types) * (
            self._normalized_config.layers_per_block + 1
        )
        down_block_additional_residuals = dummy_inputs.pop("down_block_additional_residuals", None)

        if down_block_additional_residuals:
            for idx in range(num_down_block_outputs):
                dummy_inputs[f"down_block_additional_residuals_{idx}"] = down_block_additional_residuals[idx]

        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }

        if return_tuple is True:
            return tuple(dummy_inputs.values())
        else:
            return dummy_inputs

    @property
    def is_sdxl(self) -> bool:
        return self._is_sdxl

    @is_sdxl.setter
    def is_sdxl(self, is_sdxl: bool):
        self._is_sdxl = is_sdxl

    @property
    def with_controlnet(self) -> bool:
        return self._with_controlnet

    @with_controlnet.setter
    def with_controlnet(self, with_controlnet: bool):
        self._with_controlnet = with_controlnet

    @property
    def with_ip_adapter(self) -> bool:
        return self._with_ip_adapter

    @with_ip_adapter.setter
    def with_ip_adapter(self, with_ip_adapter: bool):
        self._with_ip_adapter = with_ip_adapter
        if with_ip_adapter:
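            # IP-Adapter dummy inputs depend on the image encoder output shapes, which therefore become a mandatory axis.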
            self.mandatory_axes += ("image_encoder_shapes",)
            setattr(self, "image_encoder_shapes", self.input_shapes["image_encoder_shapes"])


@register_in_tasks_manager("pixart-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class PixartTransformerNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    INPUT_ARGS = (
        "batch_size",
        "sequence_length",
        "num_channels",
        "width",
        "height",
        "vae_scale_factor",
        "encoder_hidden_size",
    )
    MODEL_TYPE = "pixart-transformer-2d"
    CUSTOM_MODEL_WRAPPER = PixartTransformerNeuronWrapper
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        height="height",
        width="width",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyControNetInputGenerator,
        DummyTextInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["sample", "encoder_hidden_states", "timestep", "encoder_attention_mask"]
        return common_inputs

    @property
    def outputs(self) -> List[str]:
        return ["out_hidden_states"]


@register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers")
class ControlNetNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    INPUT_ARGS = (
        "batch_size",
        "sequence_length",
        "num_channels",
        "height",
        "width",
        "vae_scale_factor",
        "encoder_hidden_size",
    )
    MODEL_TYPE = "controlnet"
    CUSTOM_MODEL_WRAPPER = ControlNetNeuronWrapper
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        height="height",
        width="width",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        DummyControNetInputGenerator,  # Generates `encoder_hidden_states` in place of `DummySeq2SeqDecoderTextInputGenerator`
        DummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> List[str]:
        common_inputs = ["sample", "timestep", "encoder_hidden_states", "controlnet_cond", "conditioning_scale"]

        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs.append("text_embeds")
            common_inputs.append("time_ids")

        return common_inputs

    @property
    def outputs(self) -> List[str]:
        return ["down_block_res_samples", "mid_block_res_sample"]


@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "vae-encoder"

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="in_channels",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["sample"]

    @property
    def outputs(self) -> List[str]:
        return ["latent_parameters"]

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        dummy_inputs = super().generate_dummy_inputs(**kwargs)

        if return_tuple is True:
            return tuple(dummy_inputs.values())
        else:
            return dummy_inputs


@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"], library_name="diffusers")
class VaeDecoderNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "vae-decoder"

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_channels="latent_channels",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["latent_sample"]

    @property
    def outputs(self) -> List[str]:
        return ["sample"]

    def patch_model_for_export(
        self,
        model: "VaeDecoder",
        dummy_inputs: Dict[str, torch.Tensor],
        **kwargs,
    ):
        return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)


class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        hidden_size="d_model",
        num_attention_heads="num_heads",
        encoder_num_layers="num_layers",
        decoder_num_layers="num_decoder_layers",
        key_value_dim="d_kv",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]


@register_in_tasks_manager("t5", *["feature-extraction"], library_name="diffusers")
class T5EncoderForDiffusersNeuronConfig(T5EncoderBaseNeuronConfig):
    CUSTOM_MODEL_WRAPPER = T5EncoderWrapper
    INPUT_ARGS = ("batch_size", "sequence_length")

    @property
    def outputs(self) -> List[str]:
        return ["last_hidden_state"]

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def patch_model_for_export(self, model_or_path, **input_shapes):
        return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)


@register_in_tasks_manager("t5-encoder", *["text2text-generation"])
class T5EncoderForTransformersNeuronConfig(T5EncoderBaseNeuronConfig):
    CUSTOM_MODEL_WRAPPER = T5EncoderForSeq2SeqLMWrapper
    INPUT_ARGS = ("batch_size", "sequence_length", "num_beams")
    MODEL_TYPE = "t5-encoder"

    @property
    def outputs(self) -> List[str]:
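        # The traced encoder wrapper outputs the decoder's past key/values (self- and cross-attention) for each decoder layer.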
        common_outputs = (
            [f"present.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
            + [f"present.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
            + [f"present.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
            + [f"present.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
        )
        return common_outputs

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
        num_beams = kwargs.pop("num_beams", 1)
        sequence_length = kwargs.pop("sequence_length", None)
        batch_size = kwargs.pop("batch_size", None)

        if self.tensor_parallel_size > 1:
            # `torch.nn.Module` objects are not picklable, so the model needs to be loaded within the function.
            return partial(
                self.get_parallel_callable,
                model_or_path,
                sequence_length,
                batch_size,
                num_beams,
                device,
                self.tensor_parallel_size,
            )
        else:
            return self.CUSTOM_MODEL_WRAPPER(
                model_or_path,
                sequence_length=sequence_length,
                batch_size=batch_size,
                num_beams=num_beams,
                device=device,
                tensor_parallel_size=self.tensor_parallel_size,
            )

    def get_parallel_callable(
        self, model_name_or_path, sequence_length, batch_size, num_beams, device, tensor_parallel_size
    ):
        """Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states."""
        model = TasksManager.get_model_from_task(
            model_name_or_path=model_name_or_path,
            task=self.task,
            framework="pt",
            library_name="transformers",
        )  # TODO: add extra args, e.g. revision, trust_remote_code, etc.
        model.config.use_cache = True
        with saved_model_in_temporary_directory(model) as ckpt_path:
            # Plug in parallel layers
            from optimum.neuron.models.inference.t5.modeling_t5 import parallelize

            parallel_model = parallelize(model)
            # Load the weights into the parallel layers
            neuronx_distributed.parallel_layers.load(ckpt_path, parallel_model, sharded=False)
        encoder = self.CUSTOM_MODEL_WRAPPER(
            parallel_model,
            sequence_length=sequence_length,
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            tensor_parallel_size=tensor_parallel_size,
        )
        encoder.eval()
        aliases = self.generate_io_aliases(encoder)
        return encoder, aliases

    def generate_io_aliases(self, encoder=None):
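        # Map each KV-cache tensor held by the wrapper to an output index; these aliases let the traced model
        # update the caches in place rather than returning new tensors.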
        aliases = {}
        if self.tensor_parallel_size > 1:
            for i in range(len(encoder.past_key_values_sa)):
                aliases[encoder.past_key_values_sa[i]] = i
            for i in range(len(encoder.past_key_values_ca)):
                aliases[encoder.past_key_values_ca[i]] = len(encoder.past_key_values_sa) + i
        return aliases


@register_in_tasks_manager("t5-decoder", "text2text-generation")
class T5DecoderNeuronConfig(TextSeq2SeqNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqNeuronConfig.DUMMY_INPUT_GENERATOR_CLASSES + (DummyBeamValuesGenerator,)
    INPUT_ARGS = ("batch_size", "sequence_length", "num_beams")
    MODEL_TYPE = "t5-decoder"
    CUSTOM_MODEL_WRAPPER = T5DecoderWrapper
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

    @property
    def inputs(self) -> List[str]:
        common_inputs = [
            "decoder_input_ids",
            "decoder_attention_mask",
            "encoder_hidden_states",
            "attention_mask",  # TODO: replace with `encoder_attention_mask` after optimum 1.14 release
            "beam_idx",
            "beam_scores",
        ]
        return common_inputs

    @property
    def outputs(self) -> List[str]:
        beam_outputs = ["next_token_scores", "next_tokens", "next_indices"] if self.num_beams > 1 else ["next_tokens"]
        common_outputs = (
            beam_outputs
            + [f"past.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
            + [f"past.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
            + [f"past.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
            + [f"past.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
        )

        if self.output_hidden_states:
            # Flatten hidden states of all layers
            common_outputs += [
                f"decoder_hidden_state.{idx}" for idx in range(self._config.num_decoder_layers + 1)
            ]  # +1 for the embedding layer

        if self.output_attentions:
            # Flatten attentions tensors of all attention layers
            common_outputs += [f"decoder_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
            if getattr(self._config, "is_encoder_decoder", False) is True:
                common_outputs += [f"cross_attention.{idx}" for idx in range(self._config.num_decoder_layers)]

        return common_outputs

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def generate_dummy_inputs(self, **kwargs):
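        # The decoder runs on `batch_size * num_beams` sequences, with a single new token per step.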
        batch_size = kwargs.pop("batch_size") * kwargs.get("num_beams")
        dummy_inputs = super().generate_dummy_inputs(batch_size=batch_size, **kwargs)
        dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"][:, :1]  # sequence_length = 1
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

        return dummy_inputs

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
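        # Reuse the parent generators and append the beam values generator (the last entry of
        # `DUMMY_INPUT_GENERATOR_CLASSES`) to produce the `beam_idx` and `beam_scores` dummy inputs.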
        dummy_inputs_generators = super()._create_dummy_input_generator_classes(**kwargs)
        dummy_beam_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[-1](
            self.task,
            self._normalized_config,
            num_beams=kwargs.pop("num_beams", 1),
            **kwargs,
        )
        dummy_inputs_generators.append(dummy_beam_values_generator)
        return dummy_inputs_generators

    def patch_model_for_export(self, model, device="xla", **kwargs):
        batch_size = kwargs.pop("batch_size", 1)
        sequence_length = kwargs.pop("sequence_length", 1)
        num_beams = kwargs.pop("num_beams", 1)

        trace_args = {
            "model": model,
            "batch_size": batch_size,
            "sequence_length": sequence_length,
            "num_beams": num_beams,
            "output_hidden_states": self.output_hidden_states,
            "output_attentions": self.output_attentions,
            "device": device,
            "tensor_parallel_size": self.tensor_parallel_size,
        }
        if self.tensor_parallel_size > 1:
            return partial(
                self.get_parallel_callable,
                model,
                batch_size,
                sequence_length,
                num_beams,
                self.output_hidden_states,
                self.output_attentions,
                device,
                self.tensor_parallel_size,
            )
        else:
            return self.CUSTOM_MODEL_WRAPPER(**trace_args)

    def get_parallel_callable(
        self,
        model_name_or_path,
        batch_size,
        sequence_length,
        num_beams,
        output_hidden_states,
        output_attentions,
        device,
        tensor_parallel_size,
    ):
        """Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states."""
        model = TasksManager.get_model_from_task(
            model_name_or_path=model_name_or_path,
            task=self.task,
            framework="pt",
            library_name="transformers",
        )  # TODO: add extra args, e.g. revision, trust_remote_code, etc.
        model.config.use_cache = True
        with saved_model_in_temporary_directory(model) as ckpt_path:
            # Plug in parallel layers
            from optimum.neuron.models.inference.t5.modeling_t5 import parallelize

            parallel_model = parallelize(model)
            # Load the weights into the parallel layers
            neuronx_distributed.parallel_layers.load(ckpt_path, parallel_model, sharded=False)
        decoder = self.CUSTOM_MODEL_WRAPPER(
            parallel_model,
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_beams=num_beams,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            device=device,
            tensor_parallel_size=tensor_parallel_size,
        )
        decoder.eval()
        aliases = self.generate_io_aliases(decoder)
        return decoder, aliases

    def generate_io_aliases(self, decoder):
        num_outputs_from_trace = 3 if decoder.num_beams > 1 else 1
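        # The first traced outputs are the beam-search tensors (3 with beam search, 1 without); the KV-cache
        # aliases start right after them, self-attention caches first, then cross-attention caches.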
        aliases = {}
        for i in range(len(decoder.past_key_values_sa)):
            aliases[decoder.past_key_values_sa[i]] = i + num_outputs_from_trace
        for i in range(len(decoder.past_key_values_ca)):
            aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace

        return aliases


@register_in_tasks_manager("whisper-encoder", *["automatic-speech-recognition"])
class WhisperEncoderNeuronConfig(AudioNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "whisper-encoder"
    CUSTOM_MODEL_WRAPPER = WhisperEncoderWrapper
    INPUT_ARGS = ("batch_size", "sequence_length")
    DUMMY_INPUT_GENERATOR_CLASSES = AudioNeuronConfig.DUMMY_INPUT_GENERATOR_CLASSES + (WhisperDummyTextInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        encoder_num_layers="encoder_layers",
        decoder_num_layers="decoder_layers",
        feature_size="num_mel_bins",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["input_features", "decoder_input_ids"]

    @property
    def outputs(self) -> List[str]:
        return ["lm_logits", "encoder_last_hidden_state"]

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
        kwargs["sequence_length"] = 1  # only `decoder_start_token_id`
        return super().generate_dummy_inputs(return_tuple=return_tuple, **kwargs)

    def patch_model_for_export(self, model_or_path, **input_shapes):
        return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)


@register_in_tasks_manager("whisper-decoder", *["automatic-speech-recognition"])
class WhisperDecoderNeuronConfig(AudioNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
    MODEL_TYPE = "whisper-decoder"
    DUMMY_INPUT_GENERATOR_CLASSES = (WhisperDummyTextInputGenerator,)
    INPUT_ARGS = ("batch_size", "sequence_length")
    CUSTOM_MODEL_WRAPPER = WhisperDecoderWrapper
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        encoder_num_layers="encoder_layers",
        decoder_num_layers="decoder_layers",
        feature_size="num_mel_bins",
        hidden_size="d_model",
        allow_new=True,
    )

    @property
    def inputs(self) -> List[str]:
        return ["decoder_input_ids", "encoder_hidden_states"]

    @property
    def outputs(self) -> List[str]:
        return ["lm_logits"]

    @property
    def is_encoder_decoder(self) -> bool:
        return True

    def patch_model_for_export(self, model_or_path, **input_shapes):
        return self.CUSTOM_MODEL_WRAPPER(model_or_path, **input_shapes)
