# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the command line for the export with Neuronx compiler."""

import os
import subprocess
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from ...exporters import TasksManager
from ...utils import is_diffusers_available
from ..base import BaseOptimumCLICommand, CommandInfo


if is_diffusers_available():
    # Required to apply the optimized attention score computation for Stable Diffusion.
    os.environ["NEURON_FUSE_SOFTMAX"] = "1"
    os.environ["NEURON_CUSTOM_SILU"] = "1"

if TYPE_CHECKING:
    from argparse import ArgumentParser, Namespace, _SubParsersAction


def parse_args_neuronx(parser: "ArgumentParser"):
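    """Add the arguments of the `optimum-cli export neuron` command to the given parser."""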
    required_group = parser.add_argument_group("Required arguments")
    required_group.add_argument(
        "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
    )
    required_group.add_argument(
        "output",
        type=Path,
        help="Path indicating the directory where to store generated Neuronx compiled TorchScript model.",
    )

    optional_group = parser.add_argument_group("Optional arguments")
    optional_group.add_argument(
        "--task",
        default="auto",
        help=(
            "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
            f" {str(list(TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS.keys()) + list(TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS.keys()))}."
        ),
    )
    optional_group.add_argument(
        "--subfolder",
        type=str,
        default="",
        help=(
            "In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, specify the folder name here."
        ),
    )
    optional_group.add_argument(
        "--atol",
        type=float,
        default=None,
        help="If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.",
    )
    optional_group.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Path to a directory in which a downloaded pretrained PyTorch model weights have been cached.",
    )
    optional_group.add_argument(
        "--disable_neuron_cache",
        action="store_true",
        help="Whether to disable automatic caching of compiled models (not applicable for JIT compilation).",
    )
    optional_group.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Allow to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository.",
    )
    optional_group.add_argument(
        "--compiler_workdir",
        type=Path,
        help="Path indicating the directory where to store intermediary files generated by Neuronx compiler.",
    )
    optional_group.add_argument(
        "--inline-weights-neff",
        action="store_true",
        help="Whether to inline the weights / neff graph. It is possible to replace weights of neuron-compiled models only when the weights-neff inlining has been disabled during the compilation. So the caching will not work when this option is enabled.",
    )
    optional_group.add_argument(
        "--disable-validation",
        action="store_true",
        help="Whether to disable the validation of inference on neuron device compared to the outputs of original PyTorch model on CPU.",
    )
    optional_group.add_argument(
        "--auto_cast",
        type=str,
        default=None,
        choices=["none", "matmul", "all"],
        help='Whether to cast operations from FP32 to lower precision to speed up inference. Can be `"none"`, `"matmul"` or `"all"`.',
    )
    optional_group.add_argument(
        "--auto_cast_type",
        type=str,
        default="bf16",
        choices=["bf16", "fp16", "tf32"],
        help='The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.',
    )
    optional_group.add_argument(
        "--torch_dtype",
        type=str,
        default=None,
        choices=["bfloat16", "float16", "float32"],
        help="Override the default `torch.dtype` and load the model under this dtype. If `None` is passed, the dtype will be automatically derived from the model's weights.",
    )
    optional_group.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=1,
        help="Tensor parallelism size, the number of neuron cores on which to shard the model.",
    )
    optional_group.add_argument(
        "--dynamic-batch-size",
        action="store_true",
        help="Enable dynamic batch size for neuron compiled model. If this option is enabled, the input batch size can be a multiple of the batch size during the compilation, but it comes with a potential tradeoff in terms of latency.",
    )
    optional_group.add_argument(
        "--num_cores",
        type=int,
        default=None,
        help="The number of cores on which the model should be deployed (text-generation only).",
    )
    optional_group.add_argument(
        "--unet",
        default=None,
        help=(
            "UNet model ID on huggingface.co or path on disk to load model from. This will replace the unet in the original Stable Diffusion pipeline."
        ),
    )
    optional_group.add_argument(
        "--output_hidden_states",
        action="store_true",
        help=("Whether or not for the traced model to return the hidden states of all layers."),
    )
    optional_group.add_argument(
        "--lora_model_ids",
        default=None,
        nargs="*",
        type=str,
        help=(
            "List of model ids (eg. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights."
        ),
    )
    optional_group.add_argument(
        "--lora_weight_names",
        default=None,
        nargs="*",
        type=str,
        help="List of lora weights file names.",
    )
    optional_group.add_argument(
        "--lora_adapter_names",
        default=None,
        nargs="*",
        type=str,
        help="List of the adapter names to be used for referencing the loaded adapter models.",
    )
    optional_group.add_argument(
        "--lora_scales",
        default=None,
        nargs="*",
        type=float,
        help="List of scaling factors for the lora adapters.",
    )
    optional_group.add_argument(
        "--output_attentions",
        action="store_true",
        help="Whether or not for the traced model to return the attentions tensors of all attention layers.",
    )

    # Diffusion Only
    optional_group.add_argument(
        "--controlnet_ids",
        default=None,
        nargs="*",
        type=str,
        help="List of model ids (eg. `thibaud/controlnet-openpose-sdxl-1.0`) of ControlNet models.",
    )
    ip_adapter_group = parser.add_argument_group("IP adapters")
    ip_adapter_group.add_argument(
        "--ip_adapter_id",
        default=None,
        nargs="*",
        type=str,
        help=(
            "Model ids (eg. `h94/IP-Adapter`) of IP-Adapter models hosted on the Hub or paths to local directories containing the IP-Adapter weights."
        ),
    )
    ip_adapter_group.add_argument(
        "--ip_adapter_subfolder",
        default=None,
        nargs="*",
        type=str,
        help="The subfolder location of a model file within a larger model repository on the Hub or locally. If a list is passed, it should have the same length as `ip_adapter_weight_names`.",
    )
    ip_adapter_group.add_argument(
        "--ip_adapter_weight_name",
        default=None,
        nargs="*",
        type=str,
        help="The name of the weight file to load. If a list is passed, it should have the same length as `ip_adapter_subfolders`.",
    )
    ip_adapter_group.add_argument(
        "--ip_adapter_scale",
        default=None,
        nargs="*",
        type=float,
        help="Scaling factors for the IP-Adapters.",
    )

    # Static Input Shapes
    input_group = parser.add_argument_group("Input shapes")
    doc_input = "that the Neuronx-cc compiler exported model will be able to take as input."
    input_group.add_argument(
        "--batch_size",
        type=int,
        help=f"Batch size {doc_input}",
    )
    input_group.add_argument(
        "--text_batch_size",
        type=int,
        help=f"Batch size of the text inputs {doc_input} (Only applied for multi-modal models)",
    )
    input_group.add_argument(
        "--image_batch_size",
        type=int,
        help=f"Batch size of the vision inputs {doc_input} (Only applied for multi-modal models)",
    )
    input_group.add_argument(
        "--sequence_length",
        type=int,
        help=f"Sequence length {doc_input}",
    )
    input_group.add_argument(
        "--num_beams",
        type=int,
        help=f"Number of beams for beam search {doc_input}",
    )
    input_group.add_argument(
        "--num_choices",
        type=int,
        help=f"Only for the multiple-choice task. Num choices {doc_input}",
    )
    input_group.add_argument(
        "--num_channels",
        type=int,
        help=f"Image tasks only. Number of channels {doc_input}",
    )
    input_group.add_argument(
        "--width",
        type=int,
        help=f"Image tasks only. Width {doc_input}",
    )
    input_group.add_argument(
        "--height",
        type=int,
        help=f"Image tasks only. Height {doc_input}",
    )
    input_group.add_argument(
        "--image_size",
        type=int,
        help="Image tasks only. Size (resolution) of each image.",
    )
    input_group.add_argument(
        "--patch_size",
        type=int,
        help="Image tasks only. Size (resolution) of patch.",
    )
    input_group.add_argument(
        "--num_images_per_prompt",
        type=int,
        help=f"Stable diffusion only. Number of images per prompt {doc_input}",
    )
    input_group.add_argument(
        "--audio_sequence_length",
        type=int,
        help=f"Audio tasks only. Audio sequence length {doc_input}",
    )

    # Optimization Level
    level_group = parser.add_mutually_exclusive_group()
    level_group.add_argument(
        "-O1",
        action="store_true",
        help="Enables the core performance optimizations in the compiler, while also minimizing compile time.",
    )
    level_group.add_argument(
        "-O2",
        action="store_true",
        help="[Default] Provides the best balance between model performance and compile time.",
    )
    level_group.add_argument(
        "-O3",
        action="store_true",
        help="May provide additional model execution performance but may incur longer compile times and higher host memory usage during model compilation.",
    )
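
# A minimal sketch of exercising this parser builder directly (illustrative only;
# the model ID and shapes are assumptions, and in practice the Optimum CLI wires
# it up through `NeuronxExportCommand` below):
#
#     from argparse import ArgumentParser
#     parser = ArgumentParser("optimum-cli export neuron")
#     parse_args_neuronx(parser)
#     args = parser.parse_args(
#         ["-m", "bert-base-uncased", "--batch_size", "1", "--sequence_length", "128", "bert_neuron/"]
#     )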


class NeuronxExportCommand(BaseOptimumCLICommand):
    COMMAND = CommandInfo(name="neuron", help="Export PyTorch models to Neuronx-compiled TorchScript models.")

    def __init__(
        self,
        subparsers: "_SubParsersAction",
        args: Optional["Namespace"] = None,
        command: Optional["CommandInfo"] = None,
        from_defaults_factory: bool = False,
        parser: Optional["ArgumentParser"] = None,
    ):
        super().__init__(
            subparsers, args=args, command=command, from_defaults_factory=from_defaults_factory, parser=parser
        )
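        # Keep the raw arguments that followed `optimum-cli export neuron`
        # (sys.argv[0:3] are the program name and the `export neuron` subcommands)
        # so they can be forwarded verbatim to the exporter module in `run`.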
        self.args_string = " ".join(sys.argv[3:])

    @staticmethod
    def parse_args(parser: "ArgumentParser"):
        return parse_args_neuronx(parser)

    def run(self):
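        # Delegate the export to a fresh interpreter running the exporter module,
        # forwarding the raw CLI arguments captured in __init__ unchanged.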
        full_command = f"python3 -m optimum.exporters.neuron {self.args_string}"
        subprocess.run(full_command, shell=True, check=True)
