#  Copyright 2022 The HuggingFace Team. All rights reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from pathlib import Path
from typing import List, Tuple, Union

import onnx
from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors


def _get_onnx_external_constants(model: onnx.ModelProto) -> List[str]:
    """
    Gets the external data file locations referenced by the `Constant` nodes of the model's top-level graph.

    A tensor's `external_data` field is a list of key/value entries (e.g. `location`, `offset`, `length`),
    so only the values stored under the `location` key are collected; the other values are byte offsets,
    sizes or checksums, not file names. One entry is returned per referencing tensor (no de-duplication).
    Nodes inside subgraphs are not inspected.

    Note: load the model with `load_external_data=False`; otherwise the external data may already have
    been inlined and these references cleared.
    """
    external_constants = []

    for node in model.graph.node:
        if node.op_type != "Constant":
            continue
        for attribute in node.attribute:
            # For non-tensor attributes protobuf returns an empty default TensorProto,
            # whose `external_data` list is empty, so they are naturally skipped.
            for external_data in attribute.t.external_data:
                if external_data.key == "location":
                    external_constants.append(external_data.value)

    return external_constants


def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]:
    """
    Returns the external data `location` of every initializer whose data is stored outside the model file.

    One entry is produced per external initializer, in iteration order, so a location shared by several
    tensors appears several times. Locations are returned exactly as recorded in the model (they are not
    resolved against any directory).
    Note: make sure you load the model with load_external_data=False.
    """
    locations = []
    for initializer in _get_initializer_tensors(model):
        stored_externally = (
            initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL
        )
        if stored_externally:
            locations.append(ExternalDataInfo(initializer).location)
    return locations


def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tuple[List[Path], List[Path]]:
    """
    Adds the external data files of each ONNX model to the lists of files to copy.

    For every model in `src_paths`, the external data locations of its initializers are resolved relative
    to the model's directory and appended to `src_paths`; the matching destination, resolved relative to the
    directory of the corresponding entry in `dst_paths`, is appended to `dst_paths`. Each distinct location
    is added once per model, in the order it is first referenced. Models without external data add nothing.

    Note: both lists are extended in place and also returned. `dst_paths[i]` must be the destination of
    `src_paths[i]` for every model path initially present in `src_paths`.

    Args:
        src_paths: paths of the ONNX models to copy.
        dst_paths: destination paths, index-aligned with `src_paths`.

    Returns:
        The extended `(src_paths, dst_paths)` lists.
    """
    # Snapshot the model paths: `src_paths` grows inside the loop, and the appended external
    # data files must not themselves be loaded as ONNX models.
    model_paths = src_paths.copy()
    for idx, model_path in enumerate(model_paths):
        model = onnx.load(str(model_path), load_external_data=False)
        # Several tensors may share the same data file; de-duplicate (keeping first-seen order)
        # so each file is listed, and therefore copied, only once.
        locations = list(dict.fromkeys(_get_onnx_external_data_tensors(model)))
        dst_dir = dst_paths[idx].parent
        src_paths.extend(model_path.parent / location for location in locations)
        dst_paths.extend(dst_dir / location for location in locations)
    return src_paths, dst_paths


def _get_model_external_data_paths(model_path: Union[str, Path]) -> List[Path]:
    """
    Gets the paths of the external data files referenced by the initializers of an ONNX model.

    Locations are resolved relative to the model's directory. Each file appears once, in the order it is
    first referenced, so the result is deterministic across runs.

    Args:
        model_path: path to the ONNX model file (a `str` is accepted as well as a `Path`).

    Returns:
        The de-duplicated list of external data file paths (empty if the model has no external data).
    """
    model_path = Path(model_path)
    onnx_model = onnx.load(str(model_path), load_external_data=False)
    # dict.fromkeys (rather than a set) de-duplicates while keeping first-seen order,
    # so the returned order does not depend on hash randomization.
    return list(
        dict.fromkeys(model_path.parent / location for location in _get_onnx_external_data_tensors(onnx_model))
    )


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
    """
    Returns whether any initializer of `model` stores its data outside the model file.

    Only the tensors yielded by ONNX's `_get_initializer_tensors` are checked; tensors embedded in node
    attributes (such as `Constant` nodes) are presumably not covered.
    """
    for initializer in _get_initializer_tensors(model):
        if not initializer.HasField("data_location"):
            continue
        if initializer.data_location == onnx.TensorProto.EXTERNAL:
            # Stop at the first external tensor; no need to scan the rest.
            return True
    return False


def has_onnx_input(model: Union[onnx.ModelProto, Path, str], input_name: str) -> bool:
    """
    Returns whether the main graph of `model` declares an input named `input_name`.

    `model` may be a loaded `onnx.ModelProto` or a path to an ONNX file. A path is loaded with
    `load_external_data=False`, since only the graph's input declarations are inspected.
    """
    if isinstance(model, (str, Path)):
        posix_path = Path(model).as_posix()
        model = onnx.load(posix_path, load_external_data=False)

    return any(graph_input.name == input_name for graph_input in model.graph.input)
