# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neuron compiled model check and export functions."""

import copy
import time
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import PreTrainedModel

from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.cache.entries.multi_model import MultiModelCacheEntry
from ...neuron.cache.entries.single_model import SingleModelCacheEntry
from ...neuron.cache.traced import cache_traced_neuron_artifacts
from ...neuron.utils import (
    DiffusersPretrainedConfig,
    convert_neuronx_compiler_args_to_neuron,
    is_neuron_available,
    is_neuronx_available,
    is_neuronx_distributed_available,
    store_compilation_config,
)
from ...neuron.utils.cache_utils import get_model_name_or_path
from ...neuron.utils.version_utils import get_neuroncc_version, get_neuronxcc_version
from ...utils import (
    is_diffusers_available,
    is_sentence_transformers_available,
    logging,
)


if TYPE_CHECKING:
    from .base import NeuronDefaultConfig

if is_neuron_available():
    import torch.neuron as neuron  # noqa: F811

    NEURON_COMPILER_TYPE = "neuron-cc"
    NEURON_COMPILER_VERSION = get_neuroncc_version()

if is_neuronx_available():
    import torch_neuronx as neuronx  # noqa: F811

    NEURON_COMPILER_TYPE = "neuronx-cc"
    NEURON_COMPILER_VERSION = get_neuronxcc_version()

if is_diffusers_available():
    from diffusers import ModelMixin
    from diffusers.configuration_utils import FrozenDict

if is_sentence_transformers_available():
    from sentence_transformers import SentenceTransformer

if is_neuronx_distributed_available():
    import neuronx_distributed

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def validate_models_outputs(
    models_and_neuron_configs: Dict[
        str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]
    ],
    neuron_named_outputs: Dict[str, List[str]],
    output_dir: Path,
    atol: Optional[float] = None,
    neuron_files_subpaths: Optional[Dict[str, str]] = None,
):
    """
    Validates the export of several models, by checking that the outputs from both the reference and the exported model match.
    The following method validates the Neuron models exported using the `export_models` method.

    Args:
        models_and_neuron_configs (`Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`, `torch.nn.Module`], `NeuronDefaultConfig`]]):
            A dictionnary containing the models to export and their corresponding neuron configs.
        neuron_named_outputs (`List[List[str]]`):
            The names of the outputs to check.
        output_dir (`Path`):
            Output directory where the exported Neuron models are stored.
        atol (`Optional[float]`, defaults to `None`):
            The absolute tolerance in terms of outputs difference between the reference and the exported model.
        neuron_files_subpaths (`Optional[List[str]]`, defaults to `None`):
            The relative paths from `output_dir` to the Neuron files to do validation on. The order must be the same as the order of submodels
            in the ordered dict `models_and_neuron_configs`. If None, will use the keys from the `models_and_neuron_configs` as names.

    Raises:
        ValueError: If the outputs shapes or values do not match between the reference and the exported model.
    """
    if len(neuron_named_outputs) != len(models_and_neuron_configs.keys()):
        raise ValueError(
            f"Invalid number of Neuron named outputs. Required {models_and_neuron_configs.keys()}, Provided {neuron_named_outputs.keys()}"
        )

    if neuron_named_outputs is not None and len(neuron_named_outputs) != len(models_and_neuron_configs):
        raise ValueError(
            f"Provided custom names {neuron_files_subpaths} for the validation of {len(models_and_neuron_configs)} models. Please provide the same number of Neuron file names as models to export."
        )

    exceptions = []  # run all validations before raising
    neuron_paths = []
    for i, model_name in enumerate(models_and_neuron_configs.keys()):
        submodel, sub_neuron_config = models_and_neuron_configs[model_name]
        ref_submodel = copy.deepcopy(submodel)
        neuron_model_path = (
            output_dir.joinpath(neuron_files_subpaths[model_name])
            if neuron_files_subpaths is not None
            else output_dir.joinpath(model_name + ".neuron")
        )
        neuron_paths.append(neuron_model_path)
        try:
            logger.info(f"Validating {model_name} model...")
            validate_model_outputs(
                config=sub_neuron_config,
                reference_model=ref_submodel,
                neuron_model_path=neuron_model_path,
                neuron_named_outputs=neuron_named_outputs[model_name],
                atol=atol,
            )
        except Exception as e:
            exceptions.append(f"Validation of {model_name} fails: {e}")

    if len(exceptions) != 0:
        for i, exception in enumerate(exceptions[:-1]):
            logger.error(f"Validation {i} for the model {neuron_paths[i].as_posix()} raised: {exception}")
        raise Exception(exceptions[-1])


def validate_model_outputs(
    config: "NeuronDefaultConfig",
    reference_model: Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"],
    neuron_model_path: Path,
    neuron_named_outputs: List[str],
    atol: Optional[float] = None,
):
    """
    Validates the export by checking that the outputs from both the reference and the exported model match.

    Args:
        config ([`~optimum.neuron.exporter.NeuronDefaultConfig`]:
            The configuration used to export the model.
        reference_model ([`Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"]`]):
            The model used for the export.
        neuron_model_path (`Path`):
            The path to the exported model.
        neuron_named_outputs (`List[str]`):
            The names of the outputs to check.
        atol (`Optional[float]`, defaults to `None`):
            The absolute tolerance in terms of outputs difference between the reference and the exported model.

    Raises:
        ValueError: If the outputs shapes or values do not match between the reference and the exported model.
    """
    if atol is None:
        if isinstance(config.ATOL_FOR_VALIDATION, dict):
            atol = config.ATOL_FOR_VALIDATION[config.task]
        else:
            atol = config.ATOL_FOR_VALIDATION

    input_shapes = {}
    for axis in config.mandatory_axes:
        input_shapes[axis] = getattr(config, axis)
    if config.dynamic_batch_size is True and "batch_size" in input_shapes:
        input_shapes["batch_size"] *= 2

    # Reference outputs
    with torch.no_grad():
        reference_model.eval()
        inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
        ref_inputs = config.unflatten_inputs(inputs)
        if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False):
            reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
        if "SentenceTransformer" in reference_model.__class__.__name__:
            reference_model = config.patch_model_for_export(reference_model, ref_inputs)
            ref_outputs = reference_model(**ref_inputs)
            neuron_inputs = tuple(config.flatten_inputs(inputs).values())
        elif "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
            config._config, "is_encoder_decoder", False
        ):
            # VAE components for stable diffusion or Encoder-Decoder models
            ref_inputs = tuple(ref_inputs.values())
            ref_outputs = reference_model(*ref_inputs)
            neuron_inputs = tuple(inputs.values())
        elif config.CUSTOM_MODEL_WRAPPER is not None:
            ref_inputs = config.flatten_inputs(inputs)
            reference_model = config.patch_model_for_export(reference_model, ref_inputs, device="cpu")
            neuron_inputs = ref_inputs = tuple(ref_inputs.values())
            ref_outputs = reference_model(*ref_inputs)
        else:
            ref_outputs = reference_model(**ref_inputs)
            neuron_inputs = tuple(config.flatten_inputs(inputs).values())

    # Neuron outputs
    neuron_model = torch.jit.load(neuron_model_path)
    neuron_outputs = neuron_model(*neuron_inputs)
    if isinstance(neuron_outputs, dict):
        neuron_outputs = tuple(neuron_outputs.values())
    elif isinstance(neuron_outputs, torch.Tensor):
        neuron_outputs = (neuron_outputs,)

    # Check if we have a subset of the keys into neuron_outputs against ref_outputs
    neuron_output_names_set = set(neuron_named_outputs)
    neuron_output_names_list = sorted(neuron_output_names_set, key=neuron_named_outputs.index)

    if isinstance(ref_outputs, dict):
        ref_output_names_set = set(ref_outputs.keys())
        if not neuron_output_names_set.issubset(ref_output_names_set):
            raise OutputMatchError(
                "Neuron model output names do not match reference model output names.\n"
                f"Reference model output names: {ref_output_names_set}\n"
                f"Neuron model output names: {neuron_output_names_set}\n"
                f"Difference: {neuron_output_names_set.difference(ref_output_names_set)}"
            )
        else:
            neuron_output_names = ", ".join(neuron_output_names_set)
            logger.info(f"\t-[✓] Neuron model output names match reference model ({neuron_output_names})")
    # folowing are cases for diffusers
    elif isinstance(ref_outputs, torch.Tensor):
        ref_outputs = {neuron_named_outputs[0]: ref_outputs}
    elif isinstance(ref_outputs, tuple):
        ref_outputs = dict(zip(neuron_named_outputs, ref_outputs))

    # Check if the number of outputs matches the number of output names
    if len(neuron_output_names_set) != len(neuron_outputs):
        raise OutputMatchError(
            f"The exported Neuron model has {len(neuron_outputs)} outputs while {len(neuron_output_names_set)} are expected."
        )

    # Check the shape and values match
    shape_failures = []
    value_failures = []
    for i, (name, neuron_output) in enumerate(zip(neuron_output_names_list, neuron_outputs)):
        if isinstance(neuron_output, torch.Tensor):
            ref_output = ref_outputs[name] if isinstance(ref_outputs, dict) else ref_outputs[i]
            neuron_output = neuron_output
        elif isinstance(neuron_output, tuple):  # eg. `hidden_states` of `AutoencoderKL` is a tuple of tensors;
            ref_output = torch.stack(ref_outputs[name])
            neuron_output = torch.stack(neuron_output)
        elif isinstance(neuron_output, list):
            ref_output = ref_outputs[name]
            neuron_output = neuron_output

        logger.info(f'\t- Validating Neuron Model output "{name}":')

        # Shape
        output_list = (
            neuron_output if isinstance(neuron_output, list) else [neuron_output]
        )  # eg. `down_block_res_samples` of `ControlNet` is a list of tensors.
        ref_output_list = ref_output if isinstance(ref_output, list) else [ref_output]
        for output, ref_output in zip(output_list, ref_output_list):
            if not output.shape == ref_output.shape:
                logger.error(f"\t\t-[x] shape {output.shape} doesn't match {ref_output.shape}")
                shape_failures.append((name, ref_output.shape, output.shape))
            else:
                logger.info(f"\t\t-[✓] {output.shape} matches {ref_output.shape}")

            # Values
            if not torch.allclose(ref_output, output.to(ref_output.dtype), atol=atol):
                max_diff = torch.max(torch.abs(ref_output - output))
                logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})")
                value_failures.append((name, max_diff))
            else:
                logger.info(f"\t\t-[✓] all values close (atol: {atol})")

    if shape_failures:
        msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (neuron)" for t in shape_failures)
        raise ShapeError("Output shapes do not match between reference model and the Neuron exported model:\n{msg}")

    if value_failures:
        msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
        logger.warning(
            "The maximum absolute difference between the output of the reference model and the Neuron "
            f"exported model is not within the set tolerance {atol}:\n{msg}"
        )


def export_models(
    models_and_neuron_configs: Dict[
        str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"]
    ],
    task: str,
    output_dir: Path,
    disable_neuron_cache: Optional[bool] = False,
    compiler_workdir: Optional[Path] = None,
    inline_weights_to_neff: bool = True,
    optlevel: str = "2",
    output_file_names: Optional[Dict[str, str]] = None,
    compiler_kwargs: Optional[Dict[str, Any]] = {},
    model_name_or_path: Optional[str] = None,
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Exports a Pytorch model with multiple component models to separate files.

    Args:
        models_and_neuron_configs (`Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], `NeuronDefaultConfig`]]):
            A dictionnary containing the models to export and their corresponding neuron configs.
        task (`str`):
            The task for which the models should be exported.
        output_dir (`Path`):
            Output directory to store the exported Neuron models.
        disable_neuron_cache (`Optional[bool]`, defaults to `False`):
            Whether to disable automatic caching of AOT compiled models (not applicable for JIT compilation).
        compiler_workdir (`Optional[Path]`, defaults to `None`):
            The directory to store intermediary outputs of the neuron compiler.
        inline_weights_to_neff (`bool`, defaults to `True`):
            Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
        optlevel (`str`, defaults to `"2"`):
            The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
                1: enables the core performance optimizations in the compiler, while also minimizing compile time.
                2: provides the best balance between model performance and compile time.
                3: may provide additional model execution performance but may incur longer compile times and higher host memory usage during model compilation.
        output_file_names (`Optional[List[str]]`, defaults to `None`):
            The names to use for the exported Neuron files. The order must be the same as the order of submodels in the ordered dict `models_and_neuron_configs`.
            If None, will use the keys from `models_and_neuron_configs` as names.
        compiler_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Arguments to pass to the Neuron(x) compiler for exporting Neuron models.
        model_name_or_path (`Optional[str]`, defaults to `None`):
            Path to pretrained model or model identifier from the Hugging Face Hub.
    Returns:
        `Tuple[Dict[str, List[str]], Dict[str, List[str]]]`: A tuple with two dictionaries containing ordered list of the model's inputs, and the named
        outputs from the Neuron configuration.
    """
    all_inputs = {}
    all_outputs = {}
    if compiler_workdir is not None:
        compiler_workdir = Path(compiler_workdir)

    if output_file_names is not None and len(output_file_names) != len(models_and_neuron_configs):
        raise ValueError(
            f"Provided {len(output_file_names)} custom names for the export of {len(models_and_neuron_configs)} models. Please provide the same number of names as models to export."
        )

    failed_models = []
    total_compilation_time = 0
    compile_configs = {}
    for i, model_name in enumerate(models_and_neuron_configs.keys()):
        logger.info(f"***** Compiling {model_name} *****")
        submodel, sub_neuron_config = models_and_neuron_configs[model_name]
        output_file_name = (
            output_file_names[model_name] if output_file_names is not None else Path(model_name + ".neuron")
        )

        output_path = output_dir / output_file_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # TODO: Remove after the weights/neff separation compilation of sdxl is patched by a neuron sdk release: https://github.com/aws-neuron/aws-neuron-sdk/issues/859
        if not inline_weights_to_neff and getattr(sub_neuron_config, "is_sdxl", False):
            logger.warning(
                "The compilation of SDXL's unet with the weights/neff separation is broken since the Neuron SDK 2.18 release. `inline_weights_to_neff` will be set to True and the caching will be disabled. If you still want to separate the neff and weights, please downgrade your Neuron setup to the 2.17.1 release."
            )
            inline_weights_to_neff = True

        start_time = time.time()
        neuron_inputs, neuron_outputs = export(
            model_or_path=submodel,
            config=sub_neuron_config,
            output=output_path,
            compiler_workdir=compiler_workdir,
            inline_weights_to_neff=inline_weights_to_neff,
            optlevel=optlevel,
            **compiler_kwargs,
        )
        compilation_time = time.time() - start_time
        total_compilation_time += compilation_time
        logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.")
        all_inputs[model_name] = neuron_inputs
        all_outputs[model_name] = neuron_outputs

        # Add neuron specific configs to model components' original config
        model_config = sub_neuron_config._config
        if is_diffusers_available() and isinstance(model_config, FrozenDict):
            model_config = OrderedDict(model_config)
            model_config = DiffusersPretrainedConfig.from_dict(model_config)

        model_config = store_compilation_config(
            config=model_config,
            input_shapes=sub_neuron_config.input_shapes,
            compiler_kwargs=compiler_kwargs,
            input_names=neuron_inputs,
            output_names=neuron_outputs,
            dynamic_batch_size=sub_neuron_config.dynamic_batch_size,
            tensor_parallel_size=sub_neuron_config.tensor_parallel_size,
            compiler_type=NEURON_COMPILER_TYPE,
            compiler_version=NEURON_COMPILER_VERSION,
            inline_weights_to_neff=inline_weights_to_neff,
            optlevel=optlevel,
            model_type=getattr(sub_neuron_config, "MODEL_TYPE", None),
            task=getattr(sub_neuron_config, "task", None),
            output_attentions=getattr(sub_neuron_config, "output_attentions", False),
            output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False),
        )
        model_config.save_pretrained(output_path.parent)
        compile_configs[model_name] = model_config

    logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")

    # cache neuronx model
    if not disable_neuron_cache and is_neuronx_available():
        model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
        if len(compile_configs) == 1:
            # FIXME: this is overly complicated just to pass the config
            cache_config = list(compile_configs.values())[0]
            cache_entry = SingleModelCacheEntry(model_id=model_id, task=task, config=cache_config)
        else:
            cache_entry = MultiModelCacheEntry(model_id=model_id, configs=compile_configs)
        cache_traced_neuron_artifacts(neuron_dir=output_dir, cache_entry=cache_entry)

    # remove models failed to export
    for i, model_name in failed_models:
        output_file_names.pop(model_name)
        models_and_neuron_configs.pop(model_name)

    return all_inputs, all_outputs


def export(
    model_or_path: Union["PreTrainedModel", str, Path],
    config: "NeuronDefaultConfig",
    output: Path,
    compiler_workdir: Optional[Path] = None,
    inline_weights_to_neff: bool = True,
    optlevel: str = "2",
    auto_cast: Optional[str] = None,
    auto_cast_type: str = "bf16",
    disable_fast_relayout: bool = False,
    disable_fallback: bool = False,
) -> Tuple[List[str], List[str]]:
    if is_neuron_available():
        return export_neuron(
            model=model_or_path,
            config=config,
            output=output,
            compiler_workdir=compiler_workdir,
            inline_weights_to_neff=inline_weights_to_neff,
            auto_cast=auto_cast,
            auto_cast_type=auto_cast_type,
            disable_fast_relayout=disable_fast_relayout,
            disable_fallback=disable_fallback,
        )
    elif is_neuronx_available():
        return export_neuronx(
            model_or_path=model_or_path,
            config=config,
            output=output,
            compiler_workdir=compiler_workdir,
            inline_weights_to_neff=inline_weights_to_neff,
            optlevel=optlevel,
            auto_cast=auto_cast,
            auto_cast_type=auto_cast_type,
        )
    else:
        raise RuntimeError(
            "Cannot export the model because the neuron(x) compiler is not installed. See https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html."
        )


def export_neuronx(
    model_or_path: Union["PreTrainedModel", str, Path],
    config: "NeuronDefaultConfig",
    output: Path,
    compiler_workdir: Optional[Path] = None,
    inline_weights_to_neff: bool = True,
    optlevel: str = "2",
    auto_cast: Optional[str] = None,
    auto_cast_type: str = "bf16",
) -> Tuple[List[str], List[str]]:
    """
    Exports a PyTorch model to a serialized TorchScript module compiled by neuronx-cc compiler.

    Args:
        model_or_path (Union["PreTrainedModel", str, Path]):
            The model to export or its location(case when applying the parallelism as the model needs to be loaded with the tracing).
        config ([`~exporter.NeuronDefaultConfig`]):
            The Neuron configuration associated with the exported model.
        output (`Path`):
            Directory to store the exported Neuron model.
        compiler_workdir (`Optional[Path]`, defaults to `None`):
            The directory used by neuronx-cc, where you can find intermediary outputs (neff, weight, hlo...).
        inline_weights_to_neff (`bool`, defaults to `True`):
            Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
        optlevel (`str`, defaults to `"2"`):
            The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
                1: enables the core performance optimizations in the compiler, while also minimizing compile time.
                2: provides the best balance between model performance and compile time.
                3: may provide additional model execution performance but may incur longer compile times and higher host memory usage during model compilation.
        auto_cast (`Optional[str]`, defaults to `None`):
            Whether to cast operations from FP32 to lower precision to speed up the inference. Can be `None`, `"matmul"` or `"all"`, you should use `None` to disable any auto-casting, use `"matmul"` to cast FP32 matrix multiplication operations, and use `"all"` to cast all FP32 operations.
        auto_cast_type (`str`, defaults to `"bf16"`):
            The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the Neuron configuration.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(compiler_workdir, Path):
        compiler_workdir = compiler_workdir.as_posix()

    if hasattr(model_or_path, "config"):
        model_or_path.config.return_dict = True
        model_or_path.config.torchscript = True
    if isinstance(model_or_path, PreTrainedModel):
        model_or_path.eval()

    # Check if we need to override certain configuration item
    if config.values_override is not None:
        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
        for override_config_key, override_config_value in config.values_override.items():
            logger.info(f"\t- {override_config_key} -> {override_config_value}")
            if isinstance(model_or_path, PreTrainedModel):
                setattr(model_or_path.config, override_config_key, override_config_value)

    # Prepare dummy inputs for tracing
    input_shapes = {}
    for axis in config.mandatory_axes:
        input_shapes[axis] = getattr(config, axis)

    dummy_inputs = config.generate_dummy_inputs(**input_shapes)
    dummy_inputs = config.flatten_inputs(dummy_inputs)
    dummy_inputs_tuple = tuple(dummy_inputs.values())
    # Prepare the model / function(tp) to trace
    aliases = {}
    tensor_parallel_size = config.tensor_parallel_size
    if getattr(config, "is_encoder_decoder", False):
        checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
        if tensor_parallel_size == 1 and hasattr(config, "generate_io_aliases"):
            aliases = config.generate_io_aliases(checked_model)
    else:
        checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)

    # Construct compiler configurations
    if auto_cast is not None:
        logger.info(f"Using Neuron: --auto-cast {auto_cast}")
        auto_cast = "matmult" if auto_cast == "matmul" else auto_cast
        compiler_args = ["--auto-cast", auto_cast]

        logger.info(f"Using Neuron: --auto-cast-type {auto_cast_type}")
        compiler_args.extend(["--auto-cast-type", auto_cast_type])
    else:
        compiler_args = ["--auto-cast", "none"]

    compiler_args.extend(["--optlevel", optlevel])
    logger.info(f"Using Neuron: --optlevel {optlevel}")

    # no idea what range of models this flag could be applied, here are some exceptions that we have observed so far.
    excluded_models = {
        "unet",
        "vae-encoder",
        "vae-decoder",
        "hubert",
        "levit",
        "mobilenet-v2",
        "mobilevit",
        "unispeech",
        "unispeech-sat",
        "wav2vec2",
        "wavlm",
    }
    if config.MODEL_TYPE not in excluded_models:
        compiler_args.extend(["--model-type", "transformer"])

    compiler_args = add_stable_diffusion_compiler_args(config, compiler_args)  # diffusers specific

    if config.dynamic_batch_size and not inline_weights_to_neff:
        logger.warning(
            "Dynamic batching is not yet compatible with the weights/neff non-inlined model. `inline_weights_to_neff` is set to True. If you still want to separate the neff and weights, please set `dynamic_batch_size=False`."
        )
        inline_weights_to_neff = True

    # Start trace
    if tensor_parallel_size > 1:
        # 1. use NxD to trace for parallel
        neuron_model = neuronx_distributed.trace.parallel_model_trace(
            checked_model,
            dummy_inputs_tuple,
            compiler_args=compiler_args,
            inline_weights_to_neff=inline_weights_to_neff,
            compiler_workdir=compiler_workdir,
            tp_degree=tensor_parallel_size,
        )
        neuronx_distributed.trace.parallel_model_save(neuron_model, output)
    else:
        # 2. use `torch_neuronx.trace`
        neuron_model = neuronx.trace(
            checked_model,
            dummy_inputs_tuple,
            compiler_args=compiler_args,
            input_output_aliases=aliases,
            inline_weights_to_neff=inline_weights_to_neff,
            compiler_workdir=compiler_workdir,
        )
        if config.dynamic_batch_size is True:
            neuron_model = neuronx.dynamic_batch(neuron_model)
        # diffusers specific
        improve_stable_diffusion_loading(config, neuron_model)
        torch.jit.save(neuron_model, output)

    del model_or_path
    del checked_model
    del dummy_inputs
    del neuron_model

    return config.inputs, config.outputs


def add_stable_diffusion_compiler_args(config, compiler_args):
    # Combine the model name and its path to identify which is the subcomponent in Stable Diffusion pipeline
    identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "")
    identifier = identifier.lower()

    sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder", "controlnet"]
    if any(component in identifier for component in sd_components):
        compiler_args.append("--enable-fast-loading-neuron-binaries")
    # unet or transformer or controlnet
    if any(model_type in identifier for model_type in ["unet", "transformer", "controlnet"]):
        # SDXL unet doesn't support fast loading neuron binaries(sdk 2.19.1)
        if not getattr(config, "is_sdxl", False):
            compiler_args.append("--enable-fast-loading-neuron-binaries")
        if "unet" in identifier or "controlnet" in identifier:
            compiler_args.append("--model-type=unet-inference")
    return compiler_args


def improve_stable_diffusion_loading(config, neuron_model):
    # Combine the model name and its path to identify which is the subcomponent in Diffusion pipeline
    identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "")
    identifier = identifier.lower()
    sd_components = ["text_encoder", "unet", "transformer", "vae", "vae_encoder", "vae_decoder", "controlnet"]
    if any(component in identifier for component in sd_components):
        neuronx.async_load(neuron_model)
    # unet
    if any(model_type in identifier for model_type in ["unet", "transformer", "controlnet"]):
        neuronx.lazy_load(neuron_model)


def export_neuron(
    model: "PreTrainedModel",
    config: "NeuronDefaultConfig",
    output: Path,
    compiler_workdir: Optional[Path] = None,
    inline_weights_to_neff: bool = True,
    auto_cast: Optional[str] = None,
    auto_cast_type: str = "bf16",
    disable_fast_relayout: bool = False,
    disable_fallback: bool = False,
) -> Tuple[List[str], List[str]]:
    """
    Exports a PyTorch model to a serialized TorchScript module compiled by neuron-cc compiler.

    Args:
        model ([`PreTrainedModel`]):
            The model to export.
        config ([`~exporter.NeuronDefaultConfig`]):
            The Neuron configuration associated with the exported model.
        output (`Path`):
            Directory to store the exported Neuron model.
        compiler_workdir (`Optional[Path]`, defaults to `None`):
            The directory used by neuron-cc, where you can find intermediary outputs (neff, weight, hlo...).
        inline_weights_to_neff (`bool`, defaults to `True`):
            Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
        auto_cast (`Optional[str]`, defaults to `None`):
            Whether to cast operations from FP32 to lower precision to speed up the inference. Can be `None`, `"matmul"` or `"all"`, you should use `None` to disable any auto-casting, use `"matmul"` to cast FP32 matrix multiplication operations, and use `"all"` to cast all FP32 operations.
        auto_cast_type (`str`, defaults to `"bf16"`):
            The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"`, ``"mixed" or `"tf32"`. `"mixed"` is only available when auto_cast is "matmul", it will cast operators that use Neuron Matmult engine to bf16 while using fp16 for matmult-based transpose.
        disable_fast_relayout (`bool`, defaults to `False`):
            Whether to disable fast relayout optimization which improves performance by using the matrix multiplier for tensor transpose.
        disable_fallback (`bool`, defaults to `False`):
            Whether to disable CPU partitioning to force operations to Neuron. Defaults to `False`, as without fallback, there could be some compilation failures or performance problems.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the Neuron configuration.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(compiler_workdir, Path):
        compiler_workdir = compiler_workdir.as_posix()

    if hasattr(model, "config"):
        model.config.return_dict = True
        model.config.torchscript = True
    model.eval()

    # Check if we need to override certain configuration item
    if config.values_override is not None:
        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
        for override_config_key, override_config_value in config.values_override.items():
            logger.info(f"\t- {override_config_key} -> {override_config_value}")
            setattr(model.config, override_config_key, override_config_value)

    input_shapes = {}
    for axis in config.mandatory_axes:
        input_shapes[axis] = getattr(config, axis)

    dummy_inputs = config.generate_dummy_inputs(**input_shapes)
    dummy_inputs_tuple = tuple(dummy_inputs.values())
    checked_model = config.patch_model_for_export(model, dummy_inputs)
    compiler_args = convert_neuronx_compiler_args_to_neuron(auto_cast, auto_cast_type, disable_fast_relayout)

    if config.dynamic_batch_size is True and not inline_weights_to_neff:
        logger.warning(
            "Dynamic batching is not yet compatible with the weights/neff non-inlined model. `inline_weights_to_neff` is set to True. If you still want to separate the neff and weights, please set `dynamic_batch_size=False`."
        )
        inline_weights_to_neff = True

    neuron_model = neuron.trace(
        checked_model,
        dummy_inputs_tuple,
        dynamic_batch_size=config.dynamic_batch_size,
        compiler_args=compiler_args,
        compiler_workdir=compiler_workdir,
        separate_weights=not inline_weights_to_neff,
        fallback=not disable_fallback,
    )
    torch.jit.save(neuron_model, output)
    del model
    del checked_model
    del dummy_inputs
    del neuron_model

    return config.inputs, config.outputs
