optimum/onnx/utils.py (61 lines of code) (raw):
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import List, Tuple, Union
import onnx
from onnx.external_data_helper import ExternalDataInfo, _get_initializer_tensors
def _get_onnx_external_constants(model: onnx.ModelProto) -> List[str]:
    """
    Collects the external data entry values attached to the `Constant` nodes of the model graph.

    Every attribute of every `Constant` node is inspected and the `value` of each entry found in
    `attribute.t.external_data` is gathered, in graph order. Nodes with any other op type are skipped.
    """
    collected_values = []
    for graph_node in model.graph.node:
        # Only `Constant` nodes are of interest here.
        if graph_node.op_type != "Constant":
            continue
        for node_attribute in graph_node.attribute:
            collected_values.extend(entry.value for entry in node_attribute.t.external_data)
    return collected_values
def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]:
    """
    Returns the `location` of every initializer tensor of the model that is flagged as external data.

    One entry is produced per external tensor, so the same location can appear several times.

    Note: make sure you load the model with load_external_data=False.
    """
    external_locations = []
    for initializer in _get_initializer_tensors(model):
        # Skip tensors that do not declare a data location at all.
        if not initializer.HasField("data_location"):
            continue
        # Skip tensors whose declared data location is not external.
        if initializer.data_location != onnx.TensorProto.EXTERNAL:
            continue
        external_locations.append(ExternalDataInfo(initializer).location)
    return external_locations
def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tuple[List[Path], List[Path]]:
    """
    Gets external data paths from the models and adds them to the lists of files to copy.

    Each model in `src_paths` is loaded without its external data, and the external data files referenced by its
    initializers are appended to `src_paths` (resolved next to the model) and to `dst_paths` (resolved next to the
    matching destination `dst_paths[idx]`).

    Args:
        src_paths: Paths of the ONNX models to inspect. The entry at index `idx` is paired with `dst_paths[idx]`.
        dst_paths: Destination paths of the models.

    Returns:
        The `(src_paths, dst_paths)` pair. Both input lists are extended in place and returned.
    """
    # Iterate over a snapshot, since `src_paths` grows inside the loop.
    model_paths = src_paths.copy()
    for idx, model_path in enumerate(model_paths):
        model = onnx.load(str(model_path), load_external_data=False)
        # Keep only the tensors stored as external data. Several tensors usually share one file, so the
        # locations are deduplicated (first-seen order is kept) to list each file to copy only once.
        external_files = list(
            dict.fromkeys(
                ExternalDataInfo(tensor).location
                for tensor in _get_initializer_tensors(model)
                if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
            )
        )
        if not external_files:
            # The model embeds all of its weights: nothing to add for it.
            continue

        src_dir = model_path.parent
        dst_dir = dst_paths[idx].parent
        src_paths.extend(src_dir / file_name for file_name in external_files)
        dst_paths.extend(dst_dir / file_name for file_name in external_files)
    return src_paths, dst_paths
def _get_model_external_data_paths(model_path: Path) -> List[Path]:
    """
    Gets external data paths from the model.

    The model is loaded with `load_external_data=False`, and the location of each initializer stored as external
    data is resolved relative to the directory containing the model.

    Args:
        model_path: Path of the ONNX model file.

    Returns:
        The distinct external data file paths, in the order they are first referenced by the initializers. The
        list is empty when the model has no external initializer.
    """
    onnx_model = onnx.load(str(model_path), load_external_data=False)
    model_dir = model_path.parent
    # `dict.fromkeys` removes duplicates like a set would, but keeps a deterministic (first-seen) order
    # instead of a hash-dependent one that can change from one run to the next.
    unique_paths = dict.fromkeys(
        model_dir / ExternalDataInfo(tensor).location
        for tensor in _get_initializer_tensors(onnx_model)
        # filter out tensors that are not external data
        if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
    )
    return list(unique_paths)
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
    """
    Tells whether at least one initializer tensor of the model is stored as external data.
    """
    for initializer in _get_initializer_tensors(model):
        is_external = (
            initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL
        )
        if is_external:
            # A single external tensor is enough to answer.
            return True
    return False
def has_onnx_input(model: Union[onnx.ModelProto, Path, str], input_name: str) -> bool:
    """
    Tells whether `input_name` is the name of one of the graph inputs of the model.

    `model` is either an already loaded model, or the path (`str` or `Path`) of an ONNX file. A path is loaded
    with `load_external_data=False` before the graph inputs are inspected.
    """
    if isinstance(model, (str, Path)):
        posix_path = Path(model).as_posix()
        model = onnx.load(posix_path, load_external_data=False)

    return any(graph_input.name == input_name for graph_input in model.graph.input)