# tensorflow_similarity/indexer.py
def calibrate(self,
              predictions: FloatTensor,
              target_labels: Sequence[int],
              thresholds_targets: MutableMapping[str, float],
              calibration_metric: Union[str, ClassificationMetric] = "f1_score",  # noqa
              k: int = 1,
              matcher: Union[str, ClassificationMatch] = 'match_nearest',
              extra_metrics: Sequence[Union[str, ClassificationMetric]] = ['precision', 'recall'],  # noqa
              rounding: int = 2,
              verbose: int = 1) -> CalibrationResults:
"""Calibrate model thresholds using a test dataset.
FIXME: more detailed explanation.
Args:
predictions: TF similarity model predictions, may be a multi-headed
output.
target_labels: Sequence of the expected labels associated with the
embedded queries.
thresholds_targets: Dict of performance targets to (if possible)
meet with respect to the `calibration_metric`.
calibration_metric: [ClassificationMetric()](metrics/overview.md)
used to evaluate the performance of the index.
k: How many neighbors to use during the calibration.
Defaults to 1.
matcher: {'match_nearest', 'match_majority_vote'} or
ClassificationMatch object. Defines the classification matching,
e.g., match_nearest will count a True Positive if the query_label
is equal to the label of the nearest neighbor and the distance is
less than or equal to the distance threshold.
Defaults to 'match_nearest'.
extra_metrics: List of additional
`tf.similarity.classification_metrics.ClassificationMetric()` to
compute and report. Defaults to ['precision', 'recall'].
rounding: Metric rounding. Default to 2 digits.
verbose: Be verbose and display calibration results. Defaults to 1.
Returns:
CalibrationResults containing the thresholds and cutpoints Dicts.
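
    Usage sketch (illustrative names; assumes `index` is this Indexer,
    `model` the associated similarity model, and `x_calib`/`y_calib` a
    labeled calibration split):

        calibration = index.calibrate(
            model.predict(x_calib),
            target_labels=y_calib,
            thresholds_targets={'0.99': 0.99},
            calibration_metric='f1_score',
            k=1,
        )
        # each cutpoint maps its name to the selected distance threshold
        # ('distance'), the calibration metric value reached at it ('value'),
        # and the extra metrics
        for name, cutpoint in calibration.cutpoints.items():
            print(name, cutpoint['distance'], cutpoint['value'])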
"""
# find NN
lookups = self.batch_lookup(predictions, k=k, verbose=verbose)
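    # NOTE: `lookups` contains one entry per query, each holding that query's
    # k nearest Lookup results (neighbor label and distance).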
    # making sure our metrics are all ClassificationMetric objects
    calibration_metric = make_classification_metric(calibration_metric)
    combined_metrics: List[ClassificationMetric] = (
        [make_classification_metric(m) for m in extra_metrics])
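    # `matcher` may still be a plain string at this point; it is passed
    # through unchanged and resolved downstream by the evaluator.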
    # running calibration
    calibration_results = self.evaluator.calibrate(
        target_labels=target_labels,
        lookups=lookups,
        thresholds_targets=thresholds_targets,
        calibration_metric=calibration_metric,
        matcher=matcher,
        extra_metrics=combined_metrics,
        metric_rounding=rounding,
        verbose=verbose
    )
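    # `calibration_results.thresholds` holds the swept distance thresholds and
    # the metric values computed at each one; `cutpoints` maps each cutpoint
    # name to its selected distance and metric values (displayed below).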
    # display cutpoint results if requested
    if verbose:
        headers = ['name', 'value', 'distance']  # noqa
        cutpoints = list(calibration_results.cutpoints.values())
        # dynamically find which metrics we need. We only need to look at
        # the first cutpoints dictionary as all subsequent ones will have
        # the same metric keys.
        for metric_name in cutpoints[0].keys():
            if metric_name not in headers:
                headers.append(metric_name)
        rows = []
        for data in cutpoints:
            rows.append([data[v] for v in headers])
        print("\n", tabulate(rows, headers=headers))
    # store calibration state for serialization purposes
    self.is_calibrated = True
    self.calibration_metric = calibration_metric
    self.cutpoints = calibration_results.cutpoints
    self.calibration_thresholds = calibration_results.thresholds
    return calibration_results