optimum/onnxruntime/runs/calibrator.py (71 lines of code) (raw):

from typing import TYPE_CHECKING, Callable, Dict, List, Tuple

from ...runs_base import Calibrator
from .. import ORTQuantizer
from ..configuration import AutoCalibrationConfig, QuantizationConfig
from ..preprocessors import QuantizationPreprocessor
from ..preprocessors.passes import ExcludeGeLUNodes, ExcludeLayerNormNodes, ExcludeNodeAfter, ExcludeNodeFollowedBy


if TYPE_CHECKING:
    from datasets import Dataset


# Ordered (keyword, pass factory) pairs: when a keyword appears in ``node_exclusion``, the pass built by the
# factory is registered on the quantization preprocessor. The order is the registration order.
_NODE_EXCLUSION_PASSES: Tuple[Tuple[str, Callable[[], object]], ...] = (
    # Exclude the nodes constituting LayerNorm
    ("layernorm", ExcludeLayerNormNodes),
    # Exclude the nodes constituting GELU
    ("gelu", ExcludeGeLUNodes),
    # Exclude the residual connection Add nodes
    ("residual", lambda: ExcludeNodeAfter("Add", "Add")),
    # Exclude the Add nodes following the Gather operator
    ("gather", lambda: ExcludeNodeAfter("Gather", "Add")),
    # Exclude the Add nodes followed by the Softmax operator
    ("softmax", lambda: ExcludeNodeFollowedBy("Add", "Softmax")),
)


class OnnxRuntimeCalibrator(Calibrator):
    """Calibrate an ONNX model through an `ORTQuantizer` and return the computed ranges.

    The calibration dataset is cleaned by the quantizer at construction time. `fit` then runs the calibration
    shard by shard and returns the ranges together with a `QuantizationPreprocessor` holding the node-exclusion
    passes selected by `node_exclusion`.
    """

    def __init__(
        self,
        calibration_dataset: "Dataset",
        quantizer: ORTQuantizer,
        model_path: str,
        qconfig: QuantizationConfig,
        calibration_params: Dict,
        node_exclusion: List[str],
    ):
        super().__init__(
            calibration_dataset=calibration_dataset,
            quantizer=quantizer,
            model_path=model_path,
            qconfig=qconfig,
            calibration_params=calibration_params,
            node_exclusion=node_exclusion,
        )
        # Remove the unnecessary columns of the calibration dataset before the calibration step
        self.calibration_dataset = self.quantizer.clean_calibration_dataset(calibration_dataset)

    def fit(self, num_calibration_shards: int = 4, batch_size: int = 8):
        """Run the calibration and return ``(ranges, quantization_preprocessor)``.

        Args:
            num_calibration_shards: Number of shards the calibration dataset is split into; each shard is fed
                to `quantizer.partial_fit` in turn. The value is capped at the number of samples, so that a
                calibration set smaller than the default shard count is still accepted.
            batch_size: Batch size forwarded to `quantizer.partial_fit`.

        Returns:
            A tuple of the ranges returned by `quantizer.compute_ranges()` and the `QuantizationPreprocessor`
            carrying the node-exclusion passes.

        Raises:
            ValueError: If the calibration dataset is empty or `num_calibration_shards` is lower than 1.
        """
        quantization_preprocessor = self._build_quantization_preprocessor()
        calibration_config = self._build_calibration_config()

        # TODO estimate memory needed for entropy/percentile to autochoose number of shards
        num_samples = len(self.calibration_dataset)
        if num_calibration_shards < 1 or num_samples < 1:
            raise ValueError(
                f"Invalid value of number of shards {num_calibration_shards} chosen to split the calibration"
                " dataset, should be higher than 0 and lower or equal to the number of samples "
                f"{num_samples}."
            )
        # A shard cannot be smaller than one sample: never request more shards than there are samples.
        num_shards = min(num_calibration_shards, num_samples)

        for shard_index in range(num_shards):
            shard = self.calibration_dataset.shard(num_shards, shard_index)
            self.quantizer.partial_fit(
                dataset=shard,
                calibration_config=calibration_config,
                onnx_model_path=self.model_path,
                operators_to_quantize=self.qconfig.operators_to_quantize,
                batch_size=batch_size,
                use_external_data_format=False,
            )
        ranges = self.quantizer.compute_ranges()
        return ranges, quantization_preprocessor

    def _build_quantization_preprocessor(self) -> QuantizationPreprocessor:
        """Return a preprocessor with one exclusion pass registered per recognized `node_exclusion` keyword."""
        # Treat a missing exclusion list as "exclude nothing" instead of failing on `in None`.
        node_exclusion = self.node_exclusion or ()
        quantization_preprocessor = QuantizationPreprocessor()
        for keyword, make_pass in _NODE_EXCLUSION_PASSES:
            if keyword in node_exclusion:
                quantization_preprocessor.register_pass(make_pass())
        return quantization_preprocessor

    def _build_calibration_config(self):
        """Return the calibration configuration selected by ``calibration_params["method"]``.

        ``"entropy"`` and ``"percentile"`` select the matching `AutoCalibrationConfig` factory; any other value
        selects min-max calibration, configured with the moving-average parameters.
        """
        method = self.calibration_params["method"]
        if method == "entropy":
            return AutoCalibrationConfig.entropy(self.calibration_dataset)
        if method == "percentile":
            return AutoCalibrationConfig.percentiles(
                self.calibration_dataset,
                percentile=self.calibration_params["calibration_histogram_percentile"],
            )
        # NOTE(review): unrecognized method names silently fall back to min-max; confirm this is intended.
        return AutoCalibrationConfig.minmax(
            self.calibration_dataset,
            self.calibration_params["calibration_moving_average"],
            self.calibration_params["calibration_moving_average_constant"],
        )