scripts/utils.py (51 lines of code) (raw):

import logging
import os
from pathlib import Path
from typing import Optional, Tuple, Union

import onnx

logger = logging.getLogger(__name__)

# https://github.com/onnx/onnx/pull/6556
MAXIMUM_PROTOBUF = 2147483648  # 2GiB

# Checker errors containing this substring are tolerated; for the try/except,
# refer to https://github.com/microsoft/onnxruntime/issues/14768
_UNREGISTERED_OP_MESSAGE = "No Op registered for"


def strict_check_model(model_or_path: Union[onnx.ModelProto, str, Path]) -> None:
    """Run the full ONNX checker on an in-memory model or a saved model file.

    Errors whose message contains "No Op registered for" are ignored
    (see https://github.com/microsoft/onnxruntime/issues/14768); every other
    checker error is re-raised unchanged.

    Args:
        model_or_path: An ``onnx.ModelProto`` or a path to a saved model.
            Models of ``MAXIMUM_PROTOBUF`` bytes or more must be passed as a
            path, per the ONNX large-model checking guidance.

    Raises:
        Exception: Any checker error other than an unregistered-op error.
    """
    try:
        onnx.checker.check_model(model_or_path, full_check=True)
    except Exception as e:
        if _UNREGISTERED_OP_MESSAGE not in str(e):
            raise


def _prepare_save_path(save_path: Union[str, Path]) -> Tuple[str, str]:
    """Normalise ``save_path`` and delete files an overwrite would leave stale.

    Removes ``save_path`` itself (only when it ends with ``.onnx``) and the
    sibling external-data file ``<basename>_data``. The external-data file is
    removed even before a single-file save: the new model may be below the
    maximum protobuf size while overwriting a model that was larger.

    Returns:
        ``(posix_save_path, external_file_name)`` where ``external_file_name``
        is the bare file name used as the external-data ``location``, e.g.
        ``model.onnx_data`` for ``path/to/model.onnx``.
    """
    save_path = Path(save_path).as_posix()
    external_file_name = os.path.basename(save_path) + "_data"
    external_path = os.path.join(os.path.dirname(save_path), external_file_name)

    if save_path.endswith(".onnx") and os.path.isfile(save_path):
        os.remove(save_path)
    if os.path.isfile(external_path):
        os.remove(external_path)

    return save_path, external_file_name


def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
    """Validate ``model`` with the ONNX checker and optionally save it.

    Models smaller than ``MAXIMUM_PROTOBUF`` bytes are checked in memory and,
    if ``save_path`` is given, saved as a single file. Larger models cannot be
    checked as a ``ModelProto``: they are saved with all tensors in one
    external-data file next to the model, and the saved file is then checked
    by path. Without a ``save_path``, a large model is neither saved nor
    checked (an info message is logged instead). Existing files at the
    destination are overwritten.

    Args:
        model: The model to check and save.
        save_path: Destination path (e.g. ``path/to/model.onnx``), or
            ``None``/empty to skip saving.

    Raises:
        Exception: From ``strict_check_model`` when the model is invalid;
            errors raised by ``onnx.save`` also propagate.
    """
    # for large models, a path must be provided instead of a ModelProto:
    # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb
    if model.ByteSize() < MAXIMUM_PROTOBUF:
        strict_check_model(model)
        if save_path:
            save_path, _ = _prepare_save_path(save_path)
            onnx.save(
                model,
                save_path,
                convert_attribute=True,
            )
    elif save_path:
        save_path, external_file_name = _prepare_save_path(save_path)
        onnx.save(
            model,
            save_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_file_name,
            convert_attribute=True,
        )
        # The model was too large to check in memory; check the saved copy by path.
        strict_check_model(save_path)
    else:
        logger.info(
            "Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given."
        )