def calibrate()

in tensorflow_similarity/indexer.py [0:0]


    def calibrate(self,
                  predictions: FloatTensor,
                  target_labels: Sequence[int],
                  thresholds_targets: MutableMapping[str, float],
                  calibration_metric: Union[str, ClassificationMetric] = "f1_score",  # noqa
                  k: int = 1,
                  matcher: Union[str, ClassificationMatch] = 'match_nearest',
                  extra_metrics: Sequence[Union[str, ClassificationMetric]] = ('precision', 'recall'),  # noqa
                  rounding: int = 2,
                  verbose: int = 1) -> CalibrationResults:
        """Calibrate model thresholds using a test dataset.

        Steps performed:
          1. Look up the `k` nearest indexed neighbors of every query
             embedding in `predictions` (via `self.batch_lookup`).
          2. Convert `calibration_metric` and `extra_metrics` to
             `ClassificationMetric` objects.
          3. Delegate to `self.evaluator.calibrate()` to compute the
             calibration thresholds and the cutpoints for
             `thresholds_targets`.
          4. If `verbose`, print the cutpoints as a table.
          5. Store the results on the indexer (`is_calibrated`,
             `calibration_metric`, `cutpoints`, `calibration_thresholds`)
             so they can be serialized.

        Args:
            predictions: TF similarity model predictions, may be a multi-headed
            output.

            target_labels: Sequence of the expected labels associated with the
            embedded queries.

            thresholds_targets: Dict of performance targets to (if possible)
            meet with respect to the `calibration_metric`.

            calibration_metric: [ClassificationMetric()](metrics/overview.md)
            used to evaluate the performance of the index. May be given as a
            metric name or as a `ClassificationMetric` object.

            k: How many neighbors to use during the calibration.
            Defaults to 1.

            matcher: {'match_nearest', 'match_majority_vote'} or
            ClassificationMatch object. Defines the classification matching,
            e.g., match_nearest will count a True Positive if the query_label
            is equal to the label of the nearest neighbor and the distance is
            less than or equal to the distance threshold.
            Defaults to 'match_nearest'.

            extra_metrics: Sequence of additional
            `tf.similarity.classification_metrics.ClassificationMetric()` (or
            metric names) to compute and report.
            Defaults to ('precision', 'recall').

            rounding: Metric rounding. Defaults to 2 digits.

            verbose: Be verbose and display calibration results. Defaults to 1.

        Returns:
            The `CalibrationResults` produced by the evaluator, containing the
            thresholds and cutpoints Dicts. Both are also stored on the
            indexer.
        """

        # find the k nearest neighbors of every query
        lookups = self.batch_lookup(predictions, k=k, verbose=verbose)

        # making sure our metrics are all ClassificationMetric objects
        calibration_metric = make_classification_metric(calibration_metric)

        combined_metrics: List[ClassificationMetric] = [
            make_classification_metric(m) for m in extra_metrics
        ]

        # running calibration
        calibration_results = self.evaluator.calibrate(
            target_labels=target_labels,
            lookups=lookups,
            thresholds_targets=thresholds_targets,
            calibration_metric=calibration_metric,
            matcher=matcher,
            extra_metrics=combined_metrics,
            metric_rounding=rounding,
            verbose=verbose
        )

        # display cutpoint results if requested
        if verbose:
            cutpoints = list(calibration_results.cutpoints.values())
            # Nothing to tabulate when no cutpoints were produced; indexing
            # cutpoints[0] below would otherwise raise IndexError.
            if cutpoints:
                headers = ['name', 'value', 'distance']
                # Dynamically find which metrics we need. We only need to look
                # at the first cutpoints dictionary as all subsequent ones will
                # have the same metric keys.
                headers += [m for m in cutpoints[0] if m not in headers]

                rows = [[data[h] for h in headers] for data in cutpoints]
                # Concatenate instead of print("\n", table): print's default
                # " " separator would indent only the header row by one column.
                print("\n" + tabulate(rows, headers=headers))

        # store info for serialization purpose
        self.is_calibrated = True
        self.calibration_metric = calibration_metric
        self.cutpoints = calibration_results.cutpoints
        self.calibration_thresholds = calibration_results.thresholds
        return calibration_results