optimum/onnxruntime/optimization.py (193 lines of code) (raw):

# Copyright 2021 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main class for performing graph optimization with ONNX Runtime.""" import gc import os from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Union import onnx from onnx import load_model from transformers import GenerationConfig from transformers.models.auto.configuration_auto import AutoConfig from onnxruntime.transformers.onnx_model_bert import BertOnnxModel from onnxruntime.transformers.optimizer import optimize_model from ..onnx.utils import check_model_uses_external_data from ..utils import CONFIG_NAME, NormalizedConfigManager, logging from ..utils.save_utils import maybe_save_preprocessors from .configuration import OptimizationConfig, ORTConfig from .modeling_decoder import ORTModelForCausalLM from .modeling_ort import ORTModel from .modeling_seq2seq import ORTModelForConditionalGeneration from .utils import ONNX_WEIGHTS_NAME, ORTConfigManager if TYPE_CHECKING: from transformers import PretrainedConfig logger = logging.get_logger() class ORTOptimizer: """ Handles the ONNX Runtime optimization process for models shared on huggingface.co/models. """ def __init__(self, onnx_model_path: List[os.PathLike], config: "PretrainedConfig", from_ortmodel: bool = False): """ Args: onnx_model_path (`List[os.PathLike]`): The paths of the onnx models to optimize. config ([`~transformers.PretrainedConfig`]): An instance of the configuration associated to the model to optimize. from_ortmodel (`bool`, defaults to `False`): Whether the model being optimized is already loaded into an ORTModel, or if it was passed from disk. """ super().__init__() self.onnx_model_path = onnx_model_path self.config = config self.model_type = self.config.model_type self.from_ortmodel = from_ortmodel try: self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.model_type)(self.config) except KeyError: raise NotImplementedError( f"Tried to use ORTOptimizer for the model type {self.model_type}, but it is not available yet. Please open an issue" " or submit a PR at https://github.com/huggingface/optimum." ) @classmethod def from_pretrained( cls, model_or_path: Union[str, os.PathLike, ORTModel], file_names: Optional[List[str]] = None ) -> "ORTOptimizer": """ Args: model_or_path (`Union[str, os.PathLike, ORTModel]`): The path to a local directory hosting the model to optimize or an instance of an `ORTModel` to quantize. Can be either: - A path to a local *directory* containing the model to optimize. - An instance of [`~optimum.onnxruntime.ORTModel`]. file_names(`Optional[List[str]]`, defaults to `None`): The list of file names of the models to optimize. """ onnx_model_path = [] config = None if isinstance(model_or_path, ORTModel): from_ortmodel = True if isinstance(model_or_path, ORTModelForConditionalGeneration): onnx_model_path += [ model_or_path.encoder.path, model_or_path.decoder.path, ] # Add the decoder with past key/values if present if model_or_path.decoder_with_past is not None: onnx_model_path.append(model_or_path.decoder_with_past.path) elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged: raise NotImplementedError( "ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged. " "Please re-export your model. This can be done by using the optimum-cli ONNX export tool or `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`." ) else: onnx_model_path.append(model_or_path.path) config = model_or_path.config elif os.path.isdir(model_or_path): from_ortmodel = False file_names = [ONNX_WEIGHTS_NAME] if file_names is None else file_names model_or_path = Path(model_or_path) if CONFIG_NAME not in os.listdir(model_or_path): raise ValueError(f"The local directory does not contain the configuration file {CONFIG_NAME}.") config = AutoConfig.from_pretrained(model_or_path) for file_name in file_names: onnx_model_path.append(model_or_path.joinpath(file_name)) else: raise ValueError(f"Unable to load the model from {model_or_path}.") return cls(onnx_model_path, config=config, from_ortmodel=from_ortmodel) def optimize( self, optimization_config: OptimizationConfig, save_dir: Union[str, os.PathLike], file_suffix: Optional[str] = "optimized", use_external_data_format: Optional[bool] = None, one_external_file: bool = True, ): """ Optimizes a model given the optimization specifications defined in `optimization_config`. Args: optimization_config ([`~optimum.onnxruntime.OptimizationConfig`]): The configuration containing the parameters related to optimization. save_dir (`Union[str, os.PathLike]`): The path used to save the optimized model. file_suffix (`str`, defaults to `"optimized"`): The file suffix used to save the optimized model. use_external_data_format (`Optional[bool]`, defaults to `None`): Whether to use external data format to store model of size >= 2Gb. This argument is deprecated. one_external_file (`bool`, defaults to `True`): When `use_external_data_format=True`, whether to save all tensors to one external file. If False, save each tensor to a file named with the tensor name. """ if use_external_data_format is not None: logger.warning( "The argument use_external_data_format in the ORTOptimizer.optimize() method is deprecated and will" " be removed in optimum 2.0." ) save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) ORTConfigManager.check_optimization_supported_model(self.model_type, optimization_config) model_type = ORTConfigManager.get_model_ort_type(self.config.model_type) optimization_options = optimization_config.create_fusion_options(model_type) logger.info("Optimizing model...") # TODO: this is quite inefficient as we load in memory if models are <2GB without external data model_uses_external_data = False for model_path in self.onnx_model_path: # check if external data was exported onnx_model = onnx.load(str(model_path), load_external_data=False) if check_model_uses_external_data(onnx_model) is True: model_uses_external_data = True break del onnx_model gc.collect() # Create and save the configuration summarizing all the parameters related to optimization ort_config = ORTConfig( optimization=optimization_config, use_external_data_format=model_uses_external_data, one_external_file=one_external_file, ) for model_path in self.onnx_model_path: suffix = f"_{file_suffix}" if file_suffix else "" output_path = save_dir.joinpath(f"{model_path.stem}{suffix}").with_suffix(model_path.suffix) try: optimizer = optimize_model( model_path.as_posix(), model_type, self.normalized_config.num_attention_heads, self.normalized_config.hidden_size, opt_level=optimization_config.optimization_level, optimization_options=optimization_options, use_gpu=optimization_config.optimize_for_gpu, only_onnxruntime=not optimization_config.enable_transformers_specific_optimizations, ) if optimization_config.fp16: if model_uses_external_data: # Refer to https://github.com/microsoft/onnxruntime/blob/v1.15.0/onnxruntime/python/tools/transformers/float16.py#L204 # The ONNX infer_shapes_path method should be used instead of infer_shapes # for models >= 2 GB, and it expects a model written to disk. # Note that convert_float_to_float16 then overwrites optimizer.model as the # new ModelProto. optimizer.save_model_to_file( output_path.as_posix(), use_external_data_format=model_uses_external_data, all_tensors_to_one_file=one_external_file, ) optimizer.model = output_path.as_posix() # keep_io_types to keep inputs/outputs as float32 optimizer.convert_float_to_float16( use_symbolic_shape_infer=not optimization_config.disable_shape_inference, keep_io_types=True ) except Exception as e: if "Incomplete symbolic shape inference" in str(e): err = RuntimeError( f"{str(e)}. Try to set `disable_shape_inference=True` in your optimization configuration." ) raise err from e raise # TODO: ORT save_model_to_file will save as `.data` although we save as `.onnx_data` in the export optimizer.save_model_to_file( output_path.as_posix(), use_external_data_format=model_uses_external_data, all_tensors_to_one_file=one_external_file, ) # if loading from disk and saving in the same repository, remove previous external data if Path(model_path.as_posix() + "_data").is_file() and self.from_ortmodel is False: os.remove(model_path.as_posix() + "_data") # Save the model configuration self.config.save_pretrained(save_dir) ort_config.save_pretrained(save_dir) maybe_save_preprocessors(self.onnx_model_path[0].parent, save_dir) try: generation_config = GenerationConfig.from_pretrained(self.onnx_model_path[0].parent) generation_config.save_pretrained(save_dir) except Exception: pass logger.info( f"Optimized model saved at: {save_dir} (external data format: " f"{model_uses_external_data}; saved all tensor to one file: " f"{one_external_file})" ) return Path(save_dir) @staticmethod def get_fused_operators(onnx_model_path: Union[str, os.PathLike]) -> Dict[str, int]: """ Computes the dictionary mapping the name of the fused operators to their number of apparition in the model. Args: onnx_model_path (`Union[str, os.PathLike]`): Path of the ONNX model. Returns: The dictionary mapping the name of the fused operators to their number of apparition in the model. """ onnx_optimized_model = BertOnnxModel(load_model(onnx_model_path)) fused_operator = onnx_optimized_model.get_fused_operator_statistics() logger.info( f"The following operators were fused : { ', '.join([k for k,v in fused_operator.items() if v > 0])}" ) return {k: v for k, v in fused_operator.items() if v > 0} @staticmethod def get_nodes_number_difference( onnx_model_path: Union[str, os.PathLike], onnx_optimized_model_path: Union[str, os.PathLike] ) -> int: """ Compute the difference in the number of nodes between the original and the optimized model. Args: onnx_model_path (`Union[str, os.PathLike]`): Path of the ONNX model. onnx_optimized_model_path (`Union[str, os.PathLike]`): Path of the optimized ONNX model. Returns: The difference in the number of nodes between the original and the optimized model. """ onnx_model = BertOnnxModel(load_model(onnx_model_path)) onnx_optimized_model = BertOnnxModel(load_model(onnx_optimized_model_path)) # Information in the number of nodes decrease resulting from optimization nodes_number_onnx_model = len(onnx_model.nodes()) nodes_number_onnx_optimized_model = len(onnx_optimized_model.nodes()) difference_nodes_number = nodes_number_onnx_model - nodes_number_onnx_optimized_model logger.info( f"There are {nodes_number_onnx_model} nodes before optimization and {nodes_number_onnx_optimized_model}" f"nodes after. The number of nodes removed is {difference_nodes_number}" ) return difference_nodes_number @staticmethod def get_operators_difference( onnx_model_path: Union[str, os.PathLike], onnx_optimized_model_path: Union[str, os.PathLike] ) -> Dict[str, int]: """ Compute the dictionary mapping the operators name to the difference in the number of corresponding nodes between the original and the optimized model. Args: onnx_model_path (`Union[str, os.PathLike]`): Path of the ONNX model. onnx_optimized_model_path (`Union[str, os.PathLike]`): Path of the optimized ONNX model. Returns: The dictionary mapping the operators name to the difference in the number of corresponding nodes between the original and the optimized model. """ onnx_model = BertOnnxModel(load_model(onnx_model_path)) onnx_optimized_model = BertOnnxModel(load_model(onnx_optimized_model_path)) def nodes_difference_given_type(op_type): onnx_model_nodes_with_op_type = len(onnx_model.get_nodes_by_op_type(op_type)) onnx_optimized_model_nodes_with_op_type = len(onnx_optimized_model.get_nodes_by_op_type(op_type)) return onnx_model_nodes_with_op_type - onnx_optimized_model_nodes_with_op_type # Compute operators difference between the original and the optimized models op_types = set() for model in [onnx_model, onnx_optimized_model]: for node in model.nodes(): op_types.add(node.op_type) operators_difference = {op_type: nodes_difference_given_type(op_type) for op_type in op_types} return {k: v for k, v in operators_difference.items() if v != 0}