optimum_benchmark/backends/onnxruntime/utils.py (25 lines of code) (raw):
from typing import Any, Dict
from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantizationMode, QuantType
from optimum.pipelines import ORT_SUPPORTED_TASKS
TASKS_TO_ORTMODELS = {
task: f"optimum.onnxruntime.{task_dict['class'][0].__name__}" for task, task_dict in ORT_SUPPORTED_TASKS.items()
}
TASKS_TO_ORTPIPELINES = {
"inpainting": "optimum.onnxruntime.ORTPipelineForInpainting",
"text-to-image": "optimum.onnxruntime.ORTPipelineForText2Image",
"image-to-image": "optimum.onnxruntime.ORTPipelineForImage2Image",
}
def format_calibration_config(calibration_config: Dict[str, Any]) -> None:
if calibration_config.get("method", None) is not None:
calibration_config["method"] = CalibrationMethod[calibration_config["method"]]
return calibration_config
def format_quantization_config(quantization_config: Dict[str, Any]) -> None:
"""Format the quantization dictionary for onnxruntime."""
# the conditionals are here because some quantization strategies don't have all the options
if quantization_config.get("format", None) is not None:
quantization_config["format"] = QuantFormat.from_string(quantization_config["format"])
if quantization_config.get("mode", None) is not None:
quantization_config["mode"] = QuantizationMode.from_string(quantization_config["mode"])
if quantization_config.get("activations_dtype", None) is not None:
quantization_config["activations_dtype"] = QuantType.from_string(quantization_config["activations_dtype"])
if quantization_config.get("weights_dtype", None) is not None:
quantization_config["weights_dtype"] = QuantType.from_string(quantization_config["weights_dtype"])
return quantization_config