#  Copyright 2021 The HuggingFace Team. All rights reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import copy
import os
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

import onnx

from ..utils import logging
from .transformations_utils import (
    _create_name_sharing_dict,
    _deduplicate_gather_matmul,
    _deduplicated_cross_model_initializers,
    _find_duplicate_initializers,
    _find_matching_initializers,
    _get_all_inputs,
    _get_onnx_opset,
    _get_weights_to_tie,
    _remove_redundant_initializers,
    _replace_input_names,
    _unify_onnx_outputs,
    cast_int64_tensorproto_to_int32,
)


if TYPE_CHECKING:
    import torch.nn as nn


# Module-level logger, obtained from the package's own `logging` utilities (`..utils.logging`).
logger = logging.get_logger()


def remove_duplicate_weights(model: onnx.ModelProto, inplace: bool = False) -> onnx.ModelProto:
    """
    Deduplicates the initializers of a model: a single copy of each weight is kept, and every reference to a
    duplicate is redirected to the copy that is kept.

    Only weights that are exactly identical are considered duplicates (e.g., a transposed copy is not).

    Args:
        model (`onnx.ModelProto`):
            The model to deduplicate.
        inplace (`bool`, defaults to `False`):
            When `False`, the transformation is applied to a deep copy and the given model is left untouched.
            When `True`, the given model itself is modified.

    Returns:
        `onnx.ModelProto`: The deduplicated model (the given object when `inplace=True`, a copy otherwise).
    """
    target = model if inplace else copy.deepcopy(model)

    # Map the duplicated initializer names to the names they will share.
    name_sharing_dict = _create_name_sharing_dict(_find_duplicate_initializers(models=[target]))

    # First redirect the consumers, then drop the initializers that are not referenced anymore.
    _replace_input_names(models=[target], name_sharing_dict=name_sharing_dict)
    _remove_redundant_initializers(models=[target], name_sharing_dict=name_sharing_dict)

    return target


def remove_duplicate_weights_from_tied_info(
    onnx_model: onnx.ModelProto, torch_model: "nn.Module", tied_params: List[List[str]], save_path: str
):
    """
    Uses the weight tying information of a PyTorch model to remove potential duplicate initializers from the
    corresponding ONNX model, then checks and saves the result.

    Args:
        onnx_model (`onnx.ModelProto`):
            The ONNX model for which to tie potentially duplicate initializers.
        torch_model (`nn.Module`):
            The PyTorch model corresponding to the ONNX one.
        tied_params (`List[List[str]]`):
            A list of groups of torch parameters that are tied, i.e. shared. For them,
            the torch module shares the same pointer.
        save_path (`str`):
            The path given to `check_and_save_model` to check and save the resulting model.

    Returns:
        `onnx.ModelProto`: The ONNX model returned by the deduplication step.
    """
    tied_params_with_op, tied_groups_to_tie, tied_groups_ignored = _get_weights_to_tie(tied_params, torch_model)

    if len(tied_groups_ignored) > 0:
        logger.info(
            f"The groups of weights {tied_groups_ignored} will not be tied as either already tied or tying is not implemented."
        )

    # Position of each initializer in the graph, indexed by initializer name.
    initializer_name_to_idx = {
        initializer.name: position for position, initializer in enumerate(onnx_model.graph.initializer)
    }

    tied_groups_map = _find_matching_initializers(tied_params_with_op, onnx_model, initializer_name_to_idx)
    deduplicated_model = _deduplicate_gather_matmul(
        onnx_model, tied_groups_to_tie, tied_groups_map, initializer_name_to_idx
    )

    check_and_save_model(deduplicated_model, save_path=save_path)
    return deduplicated_model


def replace_atenops_to_gather(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Replaces broken ATenOp nodes back to Gather nodes.

    Every node whose op type is `ATenOp` or `ATen` is overwritten by a `Gather` node that consumes the first two
    inputs of the original node and produces the same outputs. The replacement is done in place: the node keeps
    its position in `model.graph.node`, so the node list is never resized while it is being iterated and the
    relative order of the nodes is preserved.

    Args:
        model (`onnx.ModelProto`):
            The ONNX model to fix. It is modified in place.

    Returns:
        `onnx.ModelProto`: The ONNX model fixed (the same object as `model`).

    Raises:
        Any error raised by `onnx.checker.check_model` on the fixed model.
    """
    for node in model.graph.node:
        if node.op_type not in ("ATenOp", "ATen"):
            continue

        # Reuse the last `_`-separated part of the original node name as suffix of the new name.
        op_num = node.name.split("_")[-1]
        gather_node = onnx.helper.make_node(
            "Gather",
            name="Gather_" + op_num,
            inputs=[node.input[0], node.input[1]],
            outputs=node.output,
        )

        # Overwrite the node where it stands instead of removing it and inserting the new node at an index
        # parsed from its name: mutating the container during iteration may skip nodes, and the parsed index
        # is not guaranteed to be a number, nor to be the position of the node in the graph.
        node.CopyFrom(gather_node)

    onnx.checker.check_model(model)
    return model


def _check_model_ignoring_unregistered_ops(model_or_path: Union[onnx.ModelProto, str]) -> None:
    """Runs `onnx.checker.check_model`, ignoring only the errors that contain "No Op registered for"."""
    # For the try catch, refer to https://github.com/microsoft/onnxruntime/issues/14768
    try:
        onnx.checker.check_model(model_or_path)
    except Exception as e:
        if "No Op registered for" not in str(e):
            raise


def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
    """
    Checks an ONNX model and saves it to `save_path`, overwriting a previous export at that location.

    Models smaller than `onnx.checker.MAXIMUM_PROTOBUF` are checked in memory before being saved. The saved file
    is then checked from its path, which is the only way to check larger models. Checker errors containing
    "No Op registered for" are ignored, any other error is raised.

    The model is saved with external data, in a single file named `<file name of save_path>_data` located next to
    `save_path`, when such a file already exists, when the model is at least `onnx.checker.MAXIMUM_PROTOBUF`
    bytes large, or when the environment variable `FORCE_ONNX_EXTERNAL_DATA` is set to `"1"`.

    Args:
        model (`onnx.ModelProto`):
            The model to check and save.
        save_path (`Optional[Union[str, Path]]`):
            Where to save the model. When `None`, nothing is written to disk: the model is only checked in
            memory, which is skipped (with a warning) for models too large to be checked that way.
    """
    # We can check ModelProtos that are smaller than 2GB before saving them.
    # For larger models, we need to save them first and then check their save path.
    # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb
    model_is_large = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
    if not model_is_large:
        _check_model_ignoring_unregistered_ops(model)

    if save_path is None:
        # Callers such as `merge_decoders` pass `None` when the model should not be saved.
        if model_is_large:
            logger.warning(
                "The ONNX model exceeds 2GB and no `save_path` was given, the model will be neither checked nor saved."
            )
        return

    save_path = Path(save_path).as_posix()
    external_file_name = os.path.basename(save_path) + "_data"
    external_file_path = os.path.join(os.path.dirname(save_path), external_file_name)

    # Overwrite a previous export.
    if save_path.endswith(".onnx") and os.path.isfile(save_path):
        os.remove(save_path)

    # A leftover external data file means that the previous export at this path used external data.
    model_uses_external_data = os.path.isfile(external_file_path)
    if model_uses_external_data:
        os.remove(external_file_path)

    FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"

    onnx.save(
        model,
        save_path,
        # A model that does not fit in a single protobuf has to be saved with external data.
        save_as_external_data=model_uses_external_data or model_is_large or FORCE_ONNX_EXTERNAL_DATA,
        all_tensors_to_one_file=True,
        location=external_file_name,
        convert_attribute=True,
        size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 100,
    )

    _check_model_ignoring_unregistered_ops(save_path)


def merge_decoders(
    decoder: Union[onnx.ModelProto, Path, str],
    decoder_with_past: Union[onnx.ModelProto, Path, str],
    graph_name: str = "merged",
    producer_name: str = "optimum-onnx",
    save_path: Optional[Union[str, Path]] = None,
    strict: bool = True,
) -> onnx.ModelProto:
    """
    Merges a decoder ONNX model and a decoder with past ONNX model into a single ONNX model, whose only node is an
    `If` node selecting one of the two decoders according to the boolean input `use_cache_branch`.

    Args:
        decoder (`Union[ModelProto, Path, str]`):
            Decoder ONNX model, or the path to load it from.
        decoder_with_past (`Union[ModelProto, Path, str]`):
            Decoder with past ONNX model, or the path to load it from.
        graph_name (`str`, defaults to `"merged"`):
            Name of the parent graph (graph of the control flow node).
        producer_name (`str`, defaults to `"optimum-onnx"`):
            Graph producer name.
        save_path (`Optional[Union[str, Path]]`, defaults to `None`):
            The path to save merged ONNX model. The model will be saved if the path is given.
        strict (`bool`, defaults to `True`):
            When set, the decoder and decoder_with_past are expected to have strictly the same number of outputs.
            When False, the decoder is allowed to have more outputs than decoder_with_past, in which case constant
            outputs are added to match the number of outputs.

    Returns:
        `~onnx.ModelProto`: The fused decoder ONNX model.

    Raises:
        `ValueError`: If the two models do not use the same opset, or if the second axis of the `attention_mask`
        input is not named `sequence_length`.
    """

    def _load_if_path(model_or_path):
        # Paths are loaded from disk, models that are already loaded are passed through.
        if isinstance(model_or_path, (str, Path)):
            return onnx.load(Path(model_or_path).as_posix())
        return model_or_path

    def _initializers_kept_in_subgraph(graph):
        # Scalars and 1D int32/int64 tensors.
        integer_types = (onnx.TensorProto.INT32, onnx.TensorProto.INT64)
        return [
            initializer
            for initializer in graph.initializer
            if len(initializer.dims) == 0 or (len(initializer.dims) == 1 and initializer.data_type in integer_types)
        ]

    decoder = _load_if_path(decoder)
    decoder_with_past = _load_if_path(decoder_with_past)

    decoder_opset = _get_onnx_opset(decoder)
    decoder_with_past_opset = _get_onnx_opset(decoder_with_past)
    if decoder_opset != decoder_with_past_opset:
        raise ValueError(
            f"Decoder's opset is {decoder_opset}, but decoder with past's opset is {decoder_with_past_opset}. Make sure having the same opset before merging."
        )

    _unify_onnx_outputs(decoder, decoder_with_past, strict=strict)
    all_inputs = _get_all_inputs([decoder, decoder_with_past])

    # The merged model can not use the same axis name for `input_ids` and `attention_mask`: both have the same
    # length on the 2nd axis in the first pass, but in later passes `input_ids` is of length 1 while
    # `attention_mask` is of length `past_sequence_length + 1`. Hence the axis `sequence_length` of
    # `attention_mask` is renamed to `attention_mask_sequence_length`.
    for graph_input in all_inputs:
        if graph_input.name != "attention_mask":
            continue
        second_axis = graph_input.type.tensor_type.shape.dim[1]
        if second_axis.dim_param != "sequence_length":
            raise ValueError("Expected attention_mask second axis to be dynamic and named `sequence_length`.")
        second_axis.dim_param = "attention_mask_sequence_length"

    deduplicated_initializers = _deduplicated_cross_model_initializers([decoder, decoder_with_past], suffix=graph_name)

    # Initializers of dim 0 (or dim 1 + int32/int64) stay in the subgraphs for readability purposes, and also
    # because ONNX Runtime breaks after optimization + merge if they do not.
    decoder_initializers = _initializers_kept_in_subgraph(decoder.graph)
    decoder_with_past_initializers = _initializers_kept_in_subgraph(decoder_with_past.graph)

    # One subgraph per decoder. They declare no inputs of their own.
    no_past_branch = onnx.helper.make_graph(
        nodes=decoder.graph.node,
        name="no_past",
        inputs=[],
        outputs=decoder.graph.output,
        initializer=decoder_initializers,
    )
    with_past_branch = onnx.helper.make_graph(
        nodes=decoder_with_past.graph.node,
        name="with_past",
        inputs=[],
        outputs=decoder_with_past.graph.output,
        initializer=decoder_with_past_initializers,
    )

    # The `If` node runs the decoder with past when `use_cache_branch` is true, the decoder without past otherwise.
    use_cache_branch = onnx.helper.make_tensor_value_info(
        name="use_cache_branch",
        elem_type=onnx.TensorProto.BOOL,
        shape=[1],
    )
    merged_output_names = [branch_output.name for branch_output in no_past_branch.output]
    if_node = onnx.helper.make_node(
        "If",
        inputs=["use_cache_branch"],
        outputs=merged_output_names,
        name="optimum::if",
        then_branch=with_past_branch,
        else_branch=no_past_branch,
    )
    merged_graph = onnx.helper.make_graph(
        nodes=[if_node],
        name=graph_name,
        inputs=all_inputs + [use_cache_branch],
        outputs=no_past_branch.output,
        initializer=deduplicated_initializers,
    )

    # Preserve the opset imports of both decoders, keeping the first import seen for each domain.
    opset_import_by_domain = {}
    for opset_import in [*decoder.opset_import, *decoder_with_past.opset_import]:
        opset_import_by_domain.setdefault(opset_import.domain, opset_import)
    opset_imports = list(opset_import_by_domain.values())

    # TODO: update IR version in the future.
    merged_model = onnx.helper.make_model_gen_version(
        merged_graph, producer_name=producer_name, opset_imports=opset_imports, ir_version=9
    )

    check_and_save_model(merged_model, save_path=save_path)

    return merged_model


def cast_slice_nodes_inputs_to_int32(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Converts the int64 constant inputs of `Slice` nodes to int32, casting the out of range values.

    The constant node inputs are stored in `model.graph.node`, and the sole way to check which node
    they are consumed by is to iterate over nodes and check `node.input` for a match. A `Constant` node
    named `name` is matched to its consumer through the input name `name_output_0`.

    Note that constant inputs to nodes such as `Squeeze`, `Unsqueeze` can not be converted to int32, as
    these operators explicitly expect int64 inputs according to ONNX specifications:
    https://github.com/onnx/onnx/blob/main/docs/Operators.md

    Args:
        model (`onnx.ModelProto`):
            The model to convert. Its constant tensors are modified in place.

    Returns:
        `onnx.ModelProto`: The same model object.
    """
    consumer_by_input = {}
    inputs_by_node = {}
    for node in model.graph.node:
        inputs_by_node[node.name] = node.input
        # When an input is consumed by several nodes, the last consumer in the graph is the one recorded.
        for input_name in node.input:
            consumer_by_input[input_name] = {"op_type": node.op_type, "node_name": node.name}

    for node in model.graph.node:
        if node.op_type != "Constant":
            continue

        tensor = node.attribute[0].t
        if tensor.data_type != onnx.TensorProto.INT64:
            continue

        consumer = consumer_by_input.get(f"{node.name}_output_0")
        if consumer is None or consumer["op_type"] != "Slice":
            continue

        logger.debug(f"Converting {node.name} to int32")

        # `Slice` node is homogeneous (requires parameters of same type), hence cast to int32 only if all of
        # its inputs (the first one aside) are constants, which is decided from their names.
        # refer to onnx/defs/schema.h
        slice_inputs = inputs_by_node[consumer["node_name"]]
        cast = all("Constant" in input_name for input_name in slice_inputs[1:])
        cast_int64_tensorproto_to_int32(tensor, cast=cast)

    return model
