from typing import TYPE_CHECKING, Dict, List

from ...runs_base import Calibrator
from .. import ORTQuantizer
from ..configuration import AutoCalibrationConfig, QuantizationConfig
from ..preprocessors import QuantizationPreprocessor
from ..preprocessors.passes import ExcludeGeLUNodes, ExcludeLayerNormNodes, ExcludeNodeAfter, ExcludeNodeFollowedBy


if TYPE_CHECKING:
    from datasets import Dataset


class OnnxRuntimeCalibrator(Calibrator):
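    """Calibrator performing the post-training static quantization calibration step for ONNX Runtime.

    A ``QuantizationPreprocessor`` is built from the requested node exclusions, a calibration
    configuration is created for the selected method (entropy, percentile or minmax), and the
    calibration ranges are computed by running the calibration dataset through the model in shards.
    """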
    def __init__(
        self,
        calibration_dataset: "Dataset",
        quantizer: ORTQuantizer,
        model_path: str,
        qconfig: QuantizationConfig,
        calibration_params: Dict,
        node_exclusion: List[str],
    ):
        super().__init__(
            calibration_dataset=calibration_dataset,
            quantizer=quantizer,
            model_path=model_path,
            qconfig=qconfig,
            calibration_params=calibration_params,
            node_exclusion=node_exclusion,
        )

        # Remove the unnecessary columns of the calibration dataset before the calibration step
        self.calibration_dataset = self.quantizer.clean_calibration_dataset(calibration_dataset)

    def fit(self):
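        """Run calibration and return the computed quantization ranges.

        Returns:
            A tuple ``(ranges, quantization_preprocessor)`` where ``ranges`` holds the calibration
            ranges computed by the quantizer and ``quantization_preprocessor`` holds the
            node-exclusion passes to apply during quantization.
        """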
        # Create the quantization preprocessor holding the node-exclusion passes
        quantization_preprocessor = QuantizationPreprocessor()

        if "layernorm" in self.node_exclusion:
            # Exclude the nodes constituting LayerNorm
            quantization_preprocessor.register_pass(ExcludeLayerNormNodes())
        if "gelu" in self.node_exclusion:
            # Exclude the nodes constituting GELU
            quantization_preprocessor.register_pass(ExcludeGeLUNodes())
        if "residual" in self.node_exclusion:
            # Exclude the residual connection Add nodes
            quantization_preprocessor.register_pass(ExcludeNodeAfter("Add", "Add"))
        if "gather" in self.node_exclusion:
            # Exclude the Add nodes following the Gather operator
            quantization_preprocessor.register_pass(ExcludeNodeAfter("Gather", "Add"))
        if "softmax" in self.node_exclusion:
            # Exclude the Add nodes followed by the Softmax operator
            quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax"))

        # Create the calibration configuration given the selected calibration method
        if self.calibration_params["method"] == "entropy":
            calibration_config = AutoCalibrationConfig.entropy(self.calibration_dataset)
        elif self.calibration_params["method"] == "percentile":
            calibration_config = AutoCalibrationConfig.percentiles(
                self.calibration_dataset,
                percentile=self.calibration_params["calibration_histogram_percentile"],
            )
        else:
            # Default to the MinMax calibration method
            calibration_config = AutoCalibrationConfig.minmax(
                self.calibration_dataset,
                self.calibration_params["calibration_moving_average"],
                self.calibration_params["calibration_moving_average_constant"],
            )

        # TODO: estimate the memory needed for entropy/percentile calibration to automatically
        # choose the number of shards
        # The calibration dataset is processed in shards to limit the memory used at once
        num_calibration_shards = 4
        if not 1 <= num_calibration_shards <= len(self.calibration_dataset):
            raise ValueError(
                f"Invalid number of calibration shards {num_calibration_shards}: it should be greater than 0 and "
                f"lower than or equal to the number of samples {len(self.calibration_dataset)}."
            )

        for i in range(num_calibration_shards):
            shard = self.calibration_dataset.shard(num_calibration_shards, i)
            self.quantizer.partial_fit(
                dataset=shard,
                calibration_config=calibration_config,
                onnx_model_path=self.model_path,
                operators_to_quantize=self.qconfig.operators_to_quantize,
                batch_size=8,
                # False assumes the serialized model fits within the 2GB protobuf limit
                use_external_data_format=False,
            )
        ranges = self.quantizer.compute_ranges()

        return ranges, quantization_preprocessor
