# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes handling quantization with ONNX Runtime."""
import logging
import os
import warnings
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import onnx
from packaging.version import Version, parse
from transformers import AutoConfig
from onnxruntime import __version__ as ort_version
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantizationMode, QuantType
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
from onnxruntime.quantization.qdq_quantizer import QDQQuantizer
from optimum.utils.import_utils import requires_backends
from ..quantization_base import OptimumQuantizer
from ..utils.save_utils import maybe_save_preprocessors
from . import ORTQuantizableOperator
from .configuration import CalibrationConfig, ORTConfig, QuantizationConfig
from .modeling_ort import ORTModel
from .modeling_seq2seq import ORTModelForConditionalGeneration
from .preprocessors import QuantizationPreprocessor
if TYPE_CHECKING:
from datasets import Dataset
from transformers import PretrainedConfig
LOGGER = logging.getLogger(__name__)
class ORTCalibrationDataReader(CalibrationDataReader):
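    """Calibration data reader that feeds batches of featurized samples from a `datasets.Dataset` to the ORT calibrator."""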
__slots__ = ["batch_size", "dataset", "_dataset_iter"]
def __init__(self, dataset: "Dataset", batch_size: int = 1):
if dataset is None:
raise ValueError("Provided dataset is None.")
if batch_size <= 0:
raise ValueError(f"Provided batch_size should be >= 1 (got: {batch_size}).")
self.dataset = dataset
self.batch_size = batch_size
self._dataset_iter = iter(self.dataset)
def get_next(self):
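        """Returns the next batch of featurized samples as a dict of lists, or `None` once the dataset is exhausted."""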
featurized_samples = None
try:
if self.batch_size == 1:
featurized_samples = {key: [value] for key, value in next(self._dataset_iter).items()}
else:
featurized_samples = defaultdict(list)
for _ in range(self.batch_size):
sample = next(self._dataset_iter)
for name, value in sample.items():
featurized_samples[name] += [value]
except StopIteration:
pass
if featurized_samples is not None and len(featurized_samples) > 0:
return featurized_samples
return None
class ORTQuantizer(OptimumQuantizer):
"""
Handles the ONNX Runtime quantization process for models shared on huggingface.co/models.
"""
def __init__(self, onnx_model_path: Path, config: Optional["PretrainedConfig"] = None):
"""
Args:
            onnx_model_path (`Path`):
                Path to the ONNX model file to quantize.
config (`Optional[PretrainedConfig]`, defaults to `None`):
The configuration of the model.
"""
super().__init__()
self.onnx_model_path = onnx_model_path
self.config = config
if self.config is None:
try:
self.config = AutoConfig.from_pretrained(self.onnx_model_path.parent)
except (OSError, ValueError):
                LOGGER.warning(
                    f"Could not automatically load the config for {self.onnx_model_path}. The quantized model will "
                    "be harder to use, since it will not be loadable by an ORTModel without the configuration being "
                    "explicitly provided."
)
self._calibrator = None
@classmethod
def from_pretrained(
cls,
model_or_path: Union["ORTModel", str, Path],
file_name: Optional[str] = None,
) -> "ORTQuantizer":
"""
        Instantiates an `ORTQuantizer` from an ONNX model file or an `ORTModel`.
Args:
model_or_path (`Union[ORTModel, str, Path]`):
Can be either:
                    - A path to a saved exported ONNX Intermediate Representation (IR) model, e.g., `./my_model_directory/`.
- Or an `ORTModelForXX` class, e.g., `ORTModelForQuestionAnswering`.
            file_name (`Optional[str]`, defaults to `None`):
Overwrites the default model file name from `"model.onnx"` to `file_name`.
This allows you to load different model files from the same repository or directory.
Returns:
An instance of `ORTQuantizer`.
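        Example (a minimal sketch; the directory path is illustrative and assumed to contain a single ONNX model file):
        ```python
        >>> from optimum.onnxruntime import ORTQuantizer
        >>> quantizer = ORTQuantizer.from_pretrained("./my_model_directory")
        ```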
"""
        ort_quantizer_error_message = "ORTQuantizer does not support multi-file quantization. Please create a separate ORTQuantizer instance for each model/file by passing the argument `file_name` to ORTQuantizer.from_pretrained()."
if isinstance(model_or_path, str):
model_or_path = Path(model_or_path)
path = None
config = None
if isinstance(model_or_path, ORTModelForConditionalGeneration):
raise NotImplementedError(ort_quantizer_error_message)
elif isinstance(model_or_path, Path) and file_name is None:
onnx_files = list(model_or_path.glob("*.onnx"))
if len(onnx_files) == 0:
raise FileNotFoundError(f"Could not find any ONNX model file in {model_or_path}")
elif len(onnx_files) > 1:
raise RuntimeError(
f"Found too many ONNX model files in {model_or_path}. {ort_quantizer_error_message}"
)
file_name = onnx_files[0].name
if isinstance(model_or_path, ORTModel):
path = Path(model_or_path.model._model_path)
config = model_or_path.config
elif os.path.isdir(model_or_path):
path = Path(model_or_path) / file_name
else:
raise ValueError(f"Unable to load model from {model_or_path}.")
return cls(path, config=config)
def fit(
self,
dataset: "Dataset",
calibration_config: CalibrationConfig,
onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx",
operators_to_quantize: Optional[List[str]] = None,
batch_size: int = 1,
use_external_data_format: bool = False,
use_gpu: bool = False,
force_symmetric_range: bool = False,
) -> Dict[str, Tuple[float, float]]:
"""
Performs the calibration step and computes the quantization ranges.
Args:
dataset (`Dataset`):
The dataset to use when performing the calibration step.
calibration_config ([`~CalibrationConfig`]):
The configuration containing the parameters related to the calibration step.
onnx_augmented_model_name (`Union[str, Path]`, defaults to `"augmented_model.onnx"`):
The path used to save the augmented model used to collect the quantization ranges.
            operators_to_quantize (`Optional[List[str]]`, defaults to `None`):
                List of the operator types to quantize.
            batch_size (`int`, defaults to 1):
                The batch size to use when collecting the quantization ranges.
            use_external_data_format (`bool`, defaults to `False`):
                Whether to use the external data format to store models whose size is >= 2GB.
            use_gpu (`bool`, defaults to `False`):
                Whether to use the GPU when collecting the quantization ranges.
            force_symmetric_range (`bool`, defaults to `False`):
                Whether to make the quantization ranges symmetric.
        Returns:
            The dictionary mapping node names to their quantization ranges.
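        Example (a sketch assuming a preprocessed `calibration_dataset` is already available; `AutoCalibrationConfig.minmax` is one of the provided calibration config builders):
        ```python
        >>> from optimum.onnxruntime.configuration import AutoCalibrationConfig
        >>> calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)
        >>> ranges = quantizer.fit(dataset=calibration_dataset, calibration_config=calibration_config)
        ```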
"""
# If a dataset is provided, then we are in a static quantization mode
        LOGGER.info(
            "Using static quantization scheme "
            f"(dataset: {calibration_config.dataset_name}, method: {calibration_config.method})"
        )
self.partial_fit(
dataset,
calibration_config,
onnx_augmented_model_name,
operators_to_quantize,
batch_size,
use_external_data_format,
use_gpu,
force_symmetric_range,
)
return self.compute_ranges()
def partial_fit(
self,
dataset: "Dataset",
calibration_config: CalibrationConfig,
onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx",
operators_to_quantize: Optional[List[str]] = None,
batch_size: int = 1,
use_external_data_format: bool = False,
use_gpu: bool = False,
force_symmetric_range: bool = False,
):
"""
Performs the calibration step and collects the quantization ranges without computing them.
Args:
dataset (`Dataset`):
The dataset to use when performing the calibration step.
calibration_config (`CalibrationConfig`):
The configuration containing the parameters related to the calibration step.
onnx_augmented_model_name (`Union[str, Path]`, defaults to `"augmented_model.onnx"`):
The path used to save the augmented model used to collect the quantization ranges.
            operators_to_quantize (`Optional[List[str]]`, defaults to `None`):
                List of the operator types to quantize.
            batch_size (`int`, defaults to 1):
                The batch size to use when collecting the quantization ranges.
            use_external_data_format (`bool`, defaults to `False`):
                Whether to use the external data format to store models whose size is >= 2GB.
            use_gpu (`bool`, defaults to `False`):
                Whether to use the GPU when collecting the quantization ranges.
            force_symmetric_range (`bool`, defaults to `False`):
                Whether to make the quantization ranges symmetric.
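        Example (a sketch of incremental calibration over dataset shards; assumes `quantizer`, `calibration_dataset` and `calibration_config` were created as in `fit`):
        ```python
        >>> for i in range(4):
        ...     shard = calibration_dataset.shard(num_shards=4, index=i)
        ...     quantizer.partial_fit(dataset=shard, calibration_config=calibration_config)
        >>> ranges = quantizer.compute_ranges()
        ```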
"""
        # Create the calibrator when a calibration method is specified
if calibration_config.method is not None:
LOGGER.info(f"Creating calibrator: {calibration_config.method}({calibration_config})")
self._calibrator = calibration_config.create_calibrator(
onnx_model_path=self.onnx_model_path.as_posix(),
use_external_data_format=use_external_data_format,
augmented_model_name=onnx_augmented_model_name,
operators_to_quantize=operators_to_quantize,
force_symmetric_range=force_symmetric_range,
)
if use_gpu:
self._calibrator.set_execution_providers(execution_providers=["CUDAExecutionProvider"])
LOGGER.info("Collecting tensors statistics...")
reader = ORTCalibrationDataReader(dataset, batch_size)
self._calibrator.collect_data(reader)
def compute_ranges(self) -> Dict[str, Tuple[float, float]]:
"""
Computes the quantization ranges.
Returns:
            The dictionary mapping node names to their quantization ranges.
"""
if self._calibrator is None:
raise ValueError(
"Calibrator is None, please call `partial_fit` or `fit` method at least ones to compute ranges."
)
LOGGER.info("Computing calibration ranges")
if parse(ort_version) >= Version("1.16.0"):
return self._calibrator.compute_data()
return self._calibrator.compute_range()
def quantize(
self,
quantization_config: QuantizationConfig,
save_dir: Union[str, Path],
file_suffix: Optional[str] = "quantized",
calibration_tensors_range: Optional[Dict[str, Tuple[float, float]]] = None,
use_external_data_format: bool = False,
preprocessor: Optional[QuantizationPreprocessor] = None,
) -> Path:
"""
Quantizes a model given the optimization specifications defined in `quantization_config`.
Args:
quantization_config (`QuantizationConfig`):
The configuration containing the parameters related to quantization.
save_dir (`Union[str, Path]`):
The directory where the quantized model should be saved.
file_suffix (`Optional[str]`, defaults to `"quantized"`):
The file_suffix used to save the quantized model.
            calibration_tensors_range (`Optional[Dict[str, Tuple[float, float]]]`, defaults to `None`):
                The dictionary mapping node names to their quantization ranges, required only when applying static quantization.
            use_external_data_format (`bool`, defaults to `False`):
                Whether to use the external data format to store models whose size is >= 2GB.
preprocessor (`Optional[QuantizationPreprocessor]`, defaults to `None`):
The preprocessor to use to collect the nodes to include or exclude from quantization.
Returns:
The path of the resulting quantized model.
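        Example (a dynamic quantization sketch; `avx512_vnni` is one of the provided `AutoQuantizationConfig` builders and the save directory is illustrative):
        ```python
        >>> from optimum.onnxruntime.configuration import AutoQuantizationConfig
        >>> dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
        >>> model_quantized_path = quantizer.quantize(save_dir="./quantized_model", quantization_config=dqconfig)
        ```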
"""
use_qdq = quantization_config.is_static and quantization_config.format == QuantFormat.QDQ
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
if quantization_config.is_static and calibration_tensors_range is None:
raise ValueError(
"Requested static quantization in the QuantizationConfig, but no calibration ranges were provided. Please run calibration first using the quantizer fit method, or use dynamic quantization."
)
if not quantization_config.is_static:
if quantization_config.mode != QuantizationMode.IntegerOps:
LOGGER.warning(
f"ONNX Runtime dynamic quantization mode should be QuantizationMode.IntegerOps "
f"(got: {quantization_config.mode})."
)
if quantization_config.activations_dtype != QuantType.QUInt8:
LOGGER.warning(
f"ONNX Runtime dynamic quantization activations data type should be QuantType.QUInt8 "
f"(got: {quantization_config.activations_dtype})."
)
LOGGER.info(
f"Creating {'static' if quantization_config.is_static else 'dynamic'} quantizer: {quantization_config}"
)
if preprocessor is not None:
LOGGER.info("Preprocessor detected, collecting nodes to include/exclude")
nodes_to_quantize, nodes_to_exclude = preprocessor.collect(self.onnx_model_path)
nodes_to_quantize.update(quantization_config.nodes_to_quantize)
nodes_to_exclude.update(quantization_config.nodes_to_exclude)
quantization_config.nodes_to_quantize = list(nodes_to_quantize)
quantization_config.nodes_to_exclude = list(nodes_to_exclude)
has_subgraphs = False
onnx_model = onnx.load(Path(self.onnx_model_path).as_posix())
for node in onnx_model.graph.node:
if node.op_type in ["If", "Loop", "Scan", "SequenceMap"]:
has_subgraphs = True
break
if has_subgraphs:
if quantization_config.is_static:
raise NotImplementedError("Static quantization is currently not supported for models with subgraphs.")
if parse(ort_version) == Version("1.16.0"):
raise ValueError(
"ONNX Runtime version v1.16.0 is not compatible with quantization for models with subgraphs, please downgrade to 1.15.1 or upgrade to a higher version. Reference: https://github.com/microsoft/onnxruntime/pull/17651"
)
quantizer_factory = QDQQuantizer if use_qdq else ONNXQuantizer
# TODO: maybe this logic can be moved to a method in the configuration class (get_ort_quantizer_kwargs())
# that returns the dictionary of arguments to pass to the quantizer factory depending on the ort version
quantizer_kwargs = {
"model": onnx_model,
"static": quantization_config.is_static,
"per_channel": quantization_config.per_channel,
"mode": quantization_config.mode,
"weight_qType": quantization_config.weights_dtype,
"input_qType": quantization_config.activations_dtype,
"tensors_range": calibration_tensors_range,
"reduce_range": quantization_config.reduce_range,
"nodes_to_quantize": quantization_config.nodes_to_quantize,
"nodes_to_exclude": quantization_config.nodes_to_exclude,
"op_types_to_quantize": [
operator.value if isinstance(operator, ORTQuantizableOperator) else operator
for operator in quantization_config.operators_to_quantize
],
"extra_options": {
"WeightSymmetric": quantization_config.weights_symmetric,
"ActivationSymmetric": quantization_config.activations_symmetric,
"EnableSubgraph": has_subgraphs,
"ForceSymmetric": quantization_config.activations_symmetric and quantization_config.weights_symmetric,
"AddQDQPairToWeight": quantization_config.qdq_add_pair_to_weight,
"DedicatedQDQPair": quantization_config.qdq_dedicated_pair,
"QDQOpTypePerChannelSupportToAxis": quantization_config.qdq_op_type_per_channel_support_to_axis,
},
}
if use_qdq:
quantizer_kwargs.pop("mode")
if parse(ort_version) >= Version("1.18.0"):
# The argument `static` has been removed from the qdq quantizer factory in ORT 1.18
quantizer_kwargs.pop("static")
if parse(ort_version) >= Version("1.13.0"):
# The argument `input_qType` has been changed into `activation_qType` in ORT 1.13
quantizer_kwargs["activation_qType"] = quantizer_kwargs.pop("input_qType")
quantizer = quantizer_factory(**quantizer_kwargs)
LOGGER.info("Quantizing model...")
quantizer.quantize_model()
suffix = f"_{file_suffix}" if file_suffix else ""
quantized_model_path = save_dir.joinpath(f"{self.onnx_model_path.stem}{suffix}").with_suffix(".onnx")
LOGGER.info(f"Saving quantized model at: {save_dir} (external data format: " f"{use_external_data_format})")
quantizer.model.save_model_to_file(quantized_model_path.as_posix(), use_external_data_format)
# Create and save the configuration summarizing all the parameters related to quantization
ort_config = ORTConfig(quantization=quantization_config, use_external_data_format=use_external_data_format)
ort_config.save_pretrained(save_dir)
if self.config is not None:
self.config.save_pretrained(save_dir)
maybe_save_preprocessors(self.onnx_model_path.parent, save_dir)
return Path(save_dir)
def get_calibration_dataset(
self,
dataset_name: str,
num_samples: int = 100,
dataset_config_name: Optional[str] = None,
dataset_split: Optional[str] = None,
preprocess_function: Optional[Callable] = None,
preprocess_batch: bool = True,
seed: int = 2016,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
) -> "Dataset":
"""
Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
Args:
            dataset_name (`str`):
                The dataset repository name on the Hugging Face Hub, or the path to a local directory containing the
                data files to use for the calibration step.
num_samples (`int`, defaults to 100):
The maximum number of samples composing the calibration dataset.
dataset_config_name (`Optional[str]`, defaults to `None`):
The name of the dataset configuration.
dataset_split (`Optional[str]`, defaults to `None`):
Which split of the dataset to use to perform the calibration step.
            preprocess_function (`Optional[Callable]`, defaults to `None`):
                Processing function to apply to each example after loading the dataset.
preprocess_batch (`bool`, defaults to `True`):
Whether the `preprocess_function` should be batched.
seed (`int`, defaults to 2016):
The random seed to use when shuffling the calibration dataset.
use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`):
Deprecated. Please use the `token` argument instead.
token (`Optional[Union[bool,str]]`, defaults to `None`):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
Returns:
The calibration `datasets.Dataset` to use for the post-training static quantization calibration
step.
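        Example (a sketch; `preprocess_fn` is a hypothetical user-defined tokenization function):
        ```python
        >>> calibration_dataset = quantizer.get_calibration_dataset(
        ...     "glue",
        ...     dataset_config_name="sst2",
        ...     preprocess_function=preprocess_fn,
        ...     num_samples=50,
        ...     dataset_split="train",
        ... )
        ```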
"""
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token
if dataset_name is None:
raise ValueError(
"ORTQuantizer: Static quantization calibration step requires a dataset_name if no calib_dataset is "
"provided."
)
requires_backends(self, ["datasets"])
from datasets import load_dataset
calib_dataset = load_dataset(
dataset_name,
name=dataset_config_name,
split=dataset_split,
token=token,
)
if num_samples is not None:
num_samples = min(num_samples, len(calib_dataset))
calib_dataset = calib_dataset.shuffle(seed=seed).select(range(num_samples))
if preprocess_function is not None:
processed_calib_dataset = calib_dataset.map(preprocess_function, batched=preprocess_batch)
else:
processed_calib_dataset = calib_dataset
return self.clean_calibration_dataset(processed_calib_dataset)
def clean_calibration_dataset(self, dataset: "Dataset") -> "Dataset":
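        """Removes the dataset columns that do not match any input of the ONNX model graph."""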
model = onnx.load(self.onnx_model_path)
        model_inputs = {model_input.name for model_input in model.graph.input}
ignored_columns = list(set(dataset.column_names) - model_inputs)
return dataset.remove_columns(ignored_columns)