# optimum_benchmark/backends/onnxruntime/backend.py
import os
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
import torch
from hydra.utils import get_class
from onnxruntime import SessionOptions
from optimum.onnxruntime import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
ORTOptimizer,
ORTQuantizer,
)
from optimum.onnxruntime.configuration import (
AutoCalibrationConfig,
AutoOptimizationConfig,
AutoQuantizationConfig,
CalibrationConfig,
OptimizationConfig,
QuantizationConfig,
)
from ...generators.dataset_generator import DatasetGenerator
from ...import_utils import is_accelerate_available, is_torch_distributed_available
from ..base import Backend
from ..transformers_utils import fast_weights_init
from .config import ORTConfig
from .utils import (
TASKS_TO_ORTMODELS,
TASKS_TO_ORTPIPELINES,
format_calibration_config,
format_quantization_config,
)
if is_accelerate_available():
from accelerate import Accelerator
if is_torch_distributed_available():
import torch.distributed
class ORTBackend(Backend[ORTConfig]):
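    """ONNX Runtime backend.

    Resolves the ORTModel / ORTDiffusionPipeline class matching the configured
    library and task, loads it into an ONNX Runtime session, optionally applies
    ORT graph optimization and/or quantization, and exposes the forward /
    prefill / generate / call entry points used for benchmarking.
    """
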
NAME: str = "onnxruntime"
def __init__(self, config: ORTConfig) -> None:
super().__init__(config)
if self.config.library != "diffusers" and self.config.task in TASKS_TO_ORTMODELS:
self.ort_model_loader = get_class(TASKS_TO_ORTMODELS[self.config.task])
self.logger.info(f"Using ORTModel class {self.ort_model_loader.__name__}")
elif self.config.library == "diffusers" and self.config.task in TASKS_TO_ORTPIPELINES:
self.ort_model_loader = get_class(TASKS_TO_ORTPIPELINES[self.config.task])
self.logger.info(f"Using ORTDiffusionPipeline class {self.ort_model_loader.__name__}")
else:
raise NotImplementedError(f"ORTBackend does not support task {self.config.task}")
def validate_execution_provider(self) -> None:
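        """Check that the requested execution provider is the one the session actually placed first in its provider list."""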
        if self.pretrained_model.providers[0] != self.config.provider:
raise ValueError(
f"{self.config.provider} is not first in providers list: {self.pretrained_model.providers}"
)
def load(self) -> None:
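        """Load the model for benchmarking.

        Creates a temporary working directory, loads either a no-weights model
        or the pretrained model, optionally optimizes and/or quantizes the
        exported ONNX files and reloads them, validates the execution provider,
        and finally cleans up the temporary directory.
        """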
self.logger.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()
if self.config.no_weights:
self.logger.info("\t+ Creating no weights ORTModel")
self.create_no_weights_model_fast()
self.logger.info("\t+ Loading no weights ORTModel")
self.load_ortmodel_with_no_weights()
else:
self.logger.info("\t+ Loading pretrained ORTModel")
self.load_ortmodel_from_pretrained()
if self.is_optimized or self.is_quantized:
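            # The exported ONNX files live in the ORTModel's model_save_dir; point
            # the config at that directory so the optimizer/quantizer below find them.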
original_model, self.config.model = self.config.model, self.pretrained_model.model_save_dir
if self.is_optimized:
self.logger.info("\t+ Applying ORT optimization")
self.optimize_onnx_files()
self.config.model = self.optimized_model
if self.is_quantized:
self.logger.info("\t+ Applying ORT quantization")
self.quantize_onnx_files()
self.config.model = self.quantized_model
if self.is_optimized or self.is_quantized:
original_export, self.config.export = self.config.export, False
self.logger.info("\t+ Loading optimized/quantized model")
self.load_ortmodel_from_pretrained()
self.config.export = original_export
self.config.model = original_model
self.logger.info("\t+ Validating requested Execution Provider")
self.validate_execution_provider()
self.logger.info("\t+ Cleaning up backend temporary directory")
self.tmpdir.cleanup()
def load_ortmodel_from_pretrained(self) -> None:
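        """Instantiate the ORT model/pipeline from `self.config.model` with the composed kwargs."""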
self.pretrained_model = self.ort_model_loader.from_pretrained(
self.config.model,
**self.config.model_kwargs,
**self.ortmodel_kwargs,
)
def load_ortmodel_with_no_weights(self) -> None:
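        """Load the no-weights model created earlier, forcing export=True so the ONNX files are generated from it."""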
with fast_weights_init():
original_model, self.config.model = self.config.model, self.no_weights_model_path.as_posix()
original_export, self.config.export = self.config.export, True
self.logger.info("\t+ Loading no weights ORTModel")
self.load_ortmodel_from_pretrained()
self.config.export = original_export
self.config.model = original_model
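
    # A feature counts as requested if either its "auto" preset or its explicit
    # config flag is set.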
@property
def is_optimized(self) -> bool:
return (self.config.auto_optimization is not None) or self.config.optimization
@property
def is_quantized(self) -> bool:
return (self.config.auto_quantization is not None) or self.config.quantization
@property
def is_calibrated(self) -> bool:
return (self.config.auto_calibration is not None) or self.config.calibration
@property
def ortmodel_kwargs(self) -> Dict[str, Any]:
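        """Keyword arguments forwarded to `from_pretrained`.

        Only options that were explicitly set are included; the `session_options`
        dict is materialized onto an `onnxruntime.SessionOptions` instance.
        """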
kwargs = {}
if self.config.export is not None:
kwargs["export"] = self.config.export
if self.config.provider is not None:
kwargs["provider"] = self.config.provider
if self.config.use_cache is not None:
kwargs["use_cache"] = self.config.use_cache
if self.config.use_merged is not None:
kwargs["use_merged"] = self.config.use_merged
if self.config.torch_dtype is not None:
kwargs["torch_dtype"] = self.config.torch_dtype
if self.config.use_io_binding is not None:
kwargs["use_io_binding"] = self.config.use_io_binding
if self.config.session_options:
kwargs["session_options"] = SessionOptions()
for key, value in self.config.session_options.items():
setattr(kwargs["session_options"], key, value)
if self.config.provider_options:
kwargs["provider_options"] = self.config.provider_options
return kwargs
@property
    def onnx_files_names(self) -> List[str]:
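        """Names of the .onnx files in the model directory.

        When `use_merged` is set, the unmerged decoder files are excluded so that
        only the merged decoder (and any other components) get processed.
        """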
assert os.path.isdir(self.config.model), f"{self.config.model} is not a directory"
if self.config.use_merged:
return [
model
for model in os.listdir(self.config.model)
if model not in [ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME] and model.endswith(".onnx")
]
else:
return [file for file in os.listdir(self.config.model) if file.endswith(".onnx")]
def optimize_onnx_files(self) -> None:
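        """Run ORTOptimizer over the exported ONNX files.

        Builds an OptimizationConfig (from an auto optimization level or an
        explicit config), optimizes into a temporary "optimized" directory, and
        saves the processor/config alongside so the directory is loadable.
        """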
self.logger.info("\t+ Attempting optimization")
self.optimized_model = os.path.join(self.tmpdir.name, "optimized")
self.logger.info("\t+ Processing optimization config")
if self.config.auto_optimization is not None:
optimization_config = AutoOptimizationConfig.with_optimization_level(
optimization_level=self.config.auto_optimization,
for_gpu=(self.config.device == "cuda"),
**self.config.auto_optimization_config,
)
elif self.config.optimization:
optimization_config = OptimizationConfig(
optimize_for_gpu=(self.config.device == "cuda"), **self.config.optimization_config
)
self.logger.info("\t+ Creating optimizer")
optimizer = ORTOptimizer.from_pretrained(self.config.model, file_names=self.onnx_files_names)
self.logger.info("\t+ Optimizing ORTModel")
optimizer.optimize(
optimization_config,
save_dir=self.optimized_model,
# TODO: add support for these
use_external_data_format=None,
one_external_file=True,
file_suffix="",
)
if self.pretrained_processor is not None:
self.pretrained_processor.save_pretrained(self.optimized_model)
if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(self.optimized_model)
def quantize_onnx_files(self) -> None:
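        """Run ORTQuantizer over the exported ONNX files.

        Supports dynamic quantization and, for single-component models, static
        (calibrated) quantization using a generated calibration dataset.
        """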
self.logger.info("\t+ Attempting quantization")
self.quantized_model = f"{self.tmpdir.name}/quantized_model"
if self.is_calibrated and len(self.onnx_files_names) > 1:
raise NotImplementedError(
"Calibrated/Static Quantization is not supported for models with multiple components. "
f"Found {len(self.onnx_files_names)} components."
)
self.logger.info("\t+ Processing quantization config")
if self.config.auto_quantization is not None:
auto_quantization_config = format_quantization_config(self.config.auto_quantization_config)
auto_quantization_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
quantization_config = auto_quantization_class(**auto_quantization_config)
elif self.config.quantization:
quantization_config = format_quantization_config(self.config.quantization_config)
quantization_config = QuantizationConfig(**quantization_config)
if self.is_calibrated:
self.logger.info("\t+ Generating calibration dataset")
dataset_shapes = {"dataset_size": 2, "sequence_length": 2, "num_choices": 2}
calibration_dataset = DatasetGenerator(
task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes
)()
columns_to_be_removed = list(set(calibration_dataset.column_names) - set(self.pretrained_model.input_names))
calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
self.logger.info("\t+ Processing calibration config")
if self.config.auto_calibration is not None:
self.logger.info("\t+ Processing calibration config")
auto_calibration_method = getattr(AutoCalibrationConfig, self.config.auto_calibration)
calibration_config = auto_calibration_method(calibration_dataset, **self.config.auto_calibration_config)
elif self.config.calibration:
self.logger.info("\t+ Processing calibration config")
calibration_config = format_calibration_config(self.config.calibration_config)
calibration_config = CalibrationConfig(
dataset_name="calibration_dataset",
dataset_split=calibration_dataset.split,
dataset_num_samples=calibration_dataset.num_rows,
dataset_config_name=calibration_dataset.config_name,
                    **calibration_config,  # use the formatted config rather than the raw dict
)
for onnx_file_name in self.onnx_files_names:
self.logger.info(f"\t+ Creating quantizer for {onnx_file_name}")
quantizer = ORTQuantizer.from_pretrained(self.config.model, file_name=onnx_file_name)
if self.is_calibrated:
self.logger.info("\t+ Fitting calibration tensors range")
calibration_tensors_range = quantizer.fit(
dataset=calibration_dataset,
use_gpu=(self.config.device == "cuda"),
calibration_config=calibration_config,
operators_to_quantize=quantization_config.operators_to_quantize,
# TODO: add support for these (maybe)
use_external_data_format=False,
force_symmetric_range=False,
batch_size=1,
)
else:
calibration_tensors_range = None
self.logger.info("\t+ Quantizing model")
quantizer.quantize(
save_dir=self.quantized_model,
quantization_config=quantization_config,
calibration_tensors_range=calibration_tensors_range,
# TODO: add support for these (maybe)
use_external_data_format=False,
preprocessor=None,
file_suffix="",
)
if self.pretrained_processor is not None:
self.pretrained_processor.save_pretrained(self.quantized_model)
if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(self.quantized_model)
@property
def split_between_processes(self) -> bool:
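        """Whether inputs should be sharded across processes, i.e. torch.distributed has been initialized."""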
return is_torch_distributed_available() and torch.distributed.is_initialized()
def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
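        """Shard inputs across processes when running distributed, move tensors to the target device, and drop keys the loaded model does not accept."""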
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.to(self.config.device)
for key in list(inputs.keys()):
if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
inputs.pop(key)
return inputs
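
    # The four entry points below are thin wrappers around the loaded model and
    # are what the benchmark scenarios call into; `prefill` reuses `generate` and
    # is expected to be invoked with kwargs that cap the number of new tokens
    # (an assumption about the caller, not enforced here).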
@torch.inference_mode()
def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.forward(**inputs, **kwargs)
@torch.inference_mode()
def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.generate(**inputs, **kwargs)
@torch.inference_mode()
def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.generate(**inputs, **kwargs)
@torch.inference_mode()
def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model(**inputs, **kwargs)
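
# Usage sketch (illustrative, not part of this module; assumes ORTConfig
# accepts these fields -- see .config for the actual schema):
#
#   from optimum_benchmark.backends.onnxruntime.backend import ORTBackend
#   from optimum_benchmark.backends.onnxruntime.config import ORTConfig
#
#   config = ORTConfig(task="text-classification", model="distilbert-base-uncased", device="cpu")
#   backend = ORTBackend(config)
#   backend.load()
#   inputs = backend.prepare_inputs({"input_ids": ..., "attention_mask": ...})
#   outputs = backend.forward(inputs, {})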