optimum_benchmark/backends/onnxruntime/backend.py (262 lines of code) (raw):

import os
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any, Dict

import torch
from hydra.utils import get_class
from onnxruntime import SessionOptions
from optimum.onnxruntime import (
    ONNX_DECODER_NAME,
    ONNX_DECODER_WITH_PAST_NAME,
    ORTOptimizer,
    ORTQuantizer,
)
from optimum.onnxruntime.configuration import (
    AutoCalibrationConfig,
    AutoOptimizationConfig,
    AutoQuantizationConfig,
    CalibrationConfig,
    OptimizationConfig,
    QuantizationConfig,
)

from ...generators.dataset_generator import DatasetGenerator
from ...import_utils import is_accelerate_available, is_torch_distributed_available
from ..base import Backend
from ..transformers_utils import fast_weights_init
from .config import ORTConfig
from .utils import (
    TASKS_TO_ORTMODELS,
    TASKS_TO_ORTPIPELINES,
    format_calibration_config,
    format_quantization_config,
)

if is_accelerate_available():
    from accelerate import Accelerator

if is_torch_distributed_available():
    import torch.distributed


class ORTBackend(Backend[ORTConfig]):
    """Benchmark backend that runs models with ONNX Runtime through Optimum.

    The loader class is resolved from the task: an ``ORTModel`` class for
    non-diffusers libraries, or an ORT diffusion pipeline when
    ``config.library == "diffusers"``. The exported ONNX files can optionally be
    optimized (``ORTOptimizer``) and/or quantized (``ORTQuantizer``) before the
    final model is loaded.
    """

    NAME: str = "onnxruntime"

    def __init__(self, config: ORTConfig) -> None:
        """Resolve the ORT model/pipeline class used to load ``config.task``.

        Raises:
            NotImplementedError: if the task has no ORT model/pipeline mapping
                for the configured library.
        """
        super().__init__(config)

        if self.config.library != "diffusers" and self.config.task in TASKS_TO_ORTMODELS:
            self.ort_model_loader = get_class(TASKS_TO_ORTMODELS[self.config.task])
            self.logger.info(f"Using ORTModel class {self.ort_model_loader.__name__}")
        elif self.config.library == "diffusers" and self.config.task in TASKS_TO_ORTPIPELINES:
            self.ort_model_loader = get_class(TASKS_TO_ORTPIPELINES[self.config.task])
            self.logger.info(f"Using ORTDiffusionPipeline class {self.ort_model_loader.__name__}")
        else:
            raise NotImplementedError(f"ORTBackend does not support task {self.config.task}")

    def validate_execution_provider(self) -> None:
        """Check that the requested execution provider is first in the loaded model's providers list.

        Raises:
            ValueError: if ``config.provider`` is not ``pretrained_model.providers[0]``.
        """
        if self.pretrained_model.providers[0] != self.config.provider:
            raise ValueError(
                f"{self.config.provider} is not first in providers list: {self.pretrained_model.providers}"
            )

    def load(self) -> None:
        """Load the ORT model into ``self.pretrained_model``.

        Loads either a no-weights model or the pretrained one, optionally
        optimizes/quantizes its ONNX files and reloads the result, then validates
        the execution provider. The backend temporary directory is always
        cleaned up, even if any step fails.
        """
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        try:
            if self.config.no_weights:
                self.logger.info("\t+ Creating no weights ORTModel")
                self.create_no_weights_model_fast()
                self.logger.info("\t+ Loading no weights ORTModel")
                self.load_ortmodel_with_no_weights()
            else:
                self.logger.info("\t+ Loading pretrained ORTModel")
                self.load_ortmodel_from_pretrained()

            if self.is_optimized or self.is_quantized:
                self._optimize_quantize_and_reload()

            self.logger.info("\t+ Validating requested Execution Provider")
            self.validate_execution_provider()
        finally:
            self.logger.info("\t+ Cleaning up backend temporary directory")
            self.tmpdir.cleanup()

    def _optimize_quantize_and_reload(self) -> None:
        """Optimize and/or quantize the loaded model's ONNX files, then reload the transformed model.

        ``config.model`` and ``config.export`` are temporarily redirected to the
        intermediate ONNX directories and are always restored, even on failure.
        """
        original_model, original_export = self.config.model, self.config.export
        try:
            # optimize_onnx_files / quantize_onnx_files read ONNX files from config.model,
            # so start from the directory the currently loaded model was saved to.
            self.config.model = self.pretrained_model.model_save_dir

            if self.is_optimized:
                self.logger.info("\t+ Applying ORT optimization")
                self.optimize_onnx_files()
                # Quantization (if requested) then operates on the optimized files.
                self.config.model = self.optimized_model

            if self.is_quantized:
                self.logger.info("\t+ Applying ORT quantization")
                self.quantize_onnx_files()
                self.config.model = self.quantized_model

            # The transformed directory already contains ONNX files, so disable export on reload.
            self.config.export = False
            self.logger.info("\t+ Loading optimized/quantized model")
            self.load_ortmodel_from_pretrained()
        finally:
            self.config.export = original_export
            self.config.model = original_model

    def load_ortmodel_from_pretrained(self) -> None:
        """Instantiate ``self.pretrained_model`` from ``config.model`` with user and ORT-specific kwargs."""
        self.pretrained_model = self.ort_model_loader.from_pretrained(
            self.config.model,
            **self.config.model_kwargs,
            **self.ortmodel_kwargs,
        )

    def load_ortmodel_with_no_weights(self) -> None:
        """Load the model stored at ``self.no_weights_model_path`` with ``export=True``.

        ``config.model`` and ``config.export`` are restored afterwards, even on failure.
        """
        original_model, original_export = self.config.model, self.config.export
        with fast_weights_init():
            try:
                # no_weights_model_path is not set in this class; presumably it is set by
                # create_no_weights_model_fast on the base Backend — TODO confirm.
                self.config.model = self.no_weights_model_path.as_posix()
                self.config.export = True
                self.load_ortmodel_from_pretrained()
            finally:
                self.config.export = original_export
                self.config.model = original_model

    @property
    def is_optimized(self) -> bool:
        """Whether ONNX optimization (automatic level or manual config) was requested."""
        return self.config.auto_optimization is not None or bool(self.config.optimization)

    @property
    def is_quantized(self) -> bool:
        """Whether ONNX quantization (automatic preset or manual config) was requested."""
        return self.config.auto_quantization is not None or bool(self.config.quantization)

    @property
    def is_calibrated(self) -> bool:
        """Whether quantization calibration (automatic method or manual config) was requested."""
        return self.config.auto_calibration is not None or bool(self.config.calibration)

    @property
    def ortmodel_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments forwarded to ``from_pretrained``.

        Scalar options left as ``None`` in the config are omitted. Session
        options, when given, are applied onto a fresh ``SessionOptions`` object.
        """
        optional_options = {
            "export": self.config.export,
            "provider": self.config.provider,
            "use_cache": self.config.use_cache,
            "use_merged": self.config.use_merged,
            "torch_dtype": self.config.torch_dtype,
            "use_io_binding": self.config.use_io_binding,
        }
        kwargs: Dict[str, Any] = {name: value for name, value in optional_options.items() if value is not None}

        if self.config.session_options:
            session_options = SessionOptions()
            for key, value in self.config.session_options.items():
                setattr(session_options, key, value)
            kwargs["session_options"] = session_options

        if self.config.provider_options:
            kwargs["provider_options"] = self.config.provider_options

        return kwargs

    @property
    def onnx_files_names(self):
        """Names of the ``.onnx`` files in the ``config.model`` directory.

        With ``use_merged``, the separate decoder / decoder-with-past files are
        excluded.
        """
        model_dir = self.config.model
        assert os.path.isdir(model_dir), f"{model_dir} is not a directory"
        excluded = {ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME} if self.config.use_merged else set()
        return [name for name in os.listdir(model_dir) if name.endswith(".onnx") and name not in excluded]

    def optimize_onnx_files(self) -> None:
        """Optimize the ONNX files in ``config.model`` into ``<tmpdir>/optimized``.

        The processor and config, when present, are saved alongside the
        optimized files.

        Raises:
            ValueError: if neither ``auto_optimization`` nor ``optimization`` is configured.
        """
        self.logger.info("\t+ Attempting optimization")
        self.optimized_model = os.path.join(self.tmpdir.name, "optimized")
        for_gpu = self.config.device == "cuda"

        self.logger.info("\t+ Processing optimization config")
        if self.config.auto_optimization is not None:
            optimization_config = AutoOptimizationConfig.with_optimization_level(
                optimization_level=self.config.auto_optimization,
                for_gpu=for_gpu,
                **self.config.auto_optimization_config,
            )
        elif self.config.optimization:
            optimization_config = OptimizationConfig(optimize_for_gpu=for_gpu, **self.config.optimization_config)
        else:
            raise ValueError("Neither auto_optimization nor optimization is configured")

        self.logger.info("\t+ Creating optimizer")
        optimizer = ORTOptimizer.from_pretrained(self.config.model, file_names=self.onnx_files_names)

        self.logger.info("\t+ Optimizing ORTModel")
        optimizer.optimize(
            optimization_config,
            save_dir=self.optimized_model,
            # TODO: add support for these
            use_external_data_format=None,
            one_external_file=True,
            file_suffix="",
        )

        if self.pretrained_processor is not None:
            self.pretrained_processor.save_pretrained(self.optimized_model)
        if self.pretrained_config is not None:
            self.pretrained_config.save_pretrained(self.optimized_model)

    def quantize_onnx_files(self) -> None:
        """Quantize each ONNX file in ``config.model`` into ``<tmpdir>/quantized_model``.

        When calibration is requested, a small synthetic dataset is generated and
        used to fit calibration tensor ranges before quantizing.

        Raises:
            NotImplementedError: if calibration is requested for a model with more
                than one ONNX file.
            ValueError: if neither ``auto_quantization`` nor ``quantization`` is configured.
        """
        self.logger.info("\t+ Attempting quantization")
        self.quantized_model = os.path.join(self.tmpdir.name, "quantized_model")
        # config.model does not change while quantizing, so list its ONNX files once.
        onnx_files_names = self.onnx_files_names
        use_gpu = self.config.device == "cuda"

        if self.is_calibrated and len(onnx_files_names) > 1:
            raise NotImplementedError(
                "Calibrated/Static Quantization is not supported for models with multiple components. "
                f"Found {len(onnx_files_names)} components."
            )

        self.logger.info("\t+ Processing quantization config")
        if self.config.auto_quantization is not None:
            auto_quantization_config = format_quantization_config(self.config.auto_quantization_config)
            auto_quantization_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
            quantization_config = auto_quantization_class(**auto_quantization_config)
        elif self.config.quantization:
            quantization_config = QuantizationConfig(**format_quantization_config(self.config.quantization_config))
        else:
            raise ValueError("Neither auto_quantization nor quantization is configured")

        calibration_dataset = None
        calibration_config = None
        if self.is_calibrated:
            self.logger.info("\t+ Generating calibration dataset")
            dataset_shapes = {"dataset_size": 2, "sequence_length": 2, "num_choices": 2}
            calibration_dataset = DatasetGenerator(
                task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes
            )()
            # Keep only the columns the ORT model accepts as inputs.
            columns_to_be_removed = list(set(calibration_dataset.column_names) - set(self.pretrained_model.input_names))
            calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)

            self.logger.info("\t+ Processing calibration config")
            if self.config.auto_calibration is not None:
                auto_calibration_method = getattr(AutoCalibrationConfig, self.config.auto_calibration)
                calibration_config = auto_calibration_method(calibration_dataset, **self.config.auto_calibration_config)
            elif self.config.calibration:
                # Use the formatted config returned by the helper (the raw config was used before).
                formatted_calibration_config = format_calibration_config(self.config.calibration_config)
                calibration_config = CalibrationConfig(
                    dataset_name="calibration_dataset",
                    dataset_split=calibration_dataset.split,
                    dataset_num_samples=calibration_dataset.num_rows,
                    dataset_config_name=calibration_dataset.config_name,
                    **formatted_calibration_config,
                )

        for onnx_file_name in onnx_files_names:
            self.logger.info(f"\t+ Creating quantizer for {onnx_file_name}")
            quantizer = ORTQuantizer.from_pretrained(self.config.model, file_name=onnx_file_name)

            calibration_tensors_range = None
            if self.is_calibrated:
                self.logger.info("\t+ Fitting calibration tensors range")
                calibration_tensors_range = quantizer.fit(
                    dataset=calibration_dataset,
                    use_gpu=use_gpu,
                    calibration_config=calibration_config,
                    operators_to_quantize=quantization_config.operators_to_quantize,
                    # TODO: add support for these (maybe)
                    use_external_data_format=False,
                    force_symmetric_range=False,
                    batch_size=1,
                )

            self.logger.info("\t+ Quantizing model")
            quantizer.quantize(
                save_dir=self.quantized_model,
                quantization_config=quantization_config,
                calibration_tensors_range=calibration_tensors_range,
                # TODO: add support for these (maybe)
                use_external_data_format=False,
                preprocessor=None,
                file_suffix="",
            )

        if self.pretrained_processor is not None:
            self.pretrained_processor.save_pretrained(self.quantized_model)
        if self.pretrained_config is not None:
            self.pretrained_config.save_pretrained(self.quantized_model)

    @property
    def split_between_processes(self) -> bool:
        """Whether inputs should be sharded: torch.distributed is available and initialized."""
        return is_torch_distributed_available() and torch.distributed.is_initialized()

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Prepare inputs for the ORT model.

        The inputs are sharded across processes when running distributed, tensors
        are moved to ``config.device``, and keys the model does not list in
        ``input_names`` are dropped (when the model exposes ``input_names``).
        Returns a new dict; the caller's mapping is not modified.
        """
        if self.split_between_processes:
            with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                inputs = process_inputs

        filter_by_name = hasattr(self.pretrained_model, "input_names")
        prepared: Dict[str, Any] = {}
        for key, value in inputs.items():
            if filter_by_name and key not in self.pretrained_model.input_names:
                continue
            prepared[key] = value.to(self.config.device) if isinstance(value, torch.Tensor) else value
        return prepared

    @torch.inference_mode()
    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run a single forward pass of the loaded model."""
        return self.pretrained_model.forward(**inputs, **kwargs)

    @torch.inference_mode()
    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run ``generate`` for the prefill phase.

        Presumably the caller passes kwargs limiting generation (e.g. a small
        ``max_new_tokens``); TODO confirm.
        """
        return self.pretrained_model.generate(**inputs, **kwargs)

    @torch.inference_mode()
    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run ``generate`` on the loaded model."""
        return self.pretrained_model.generate(**inputs, **kwargs)

    @torch.inference_mode()
    def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Call the loaded model/pipeline directly (e.g. diffusion pipelines)."""
        return self.pretrained_model(**inputs, **kwargs)