in tensorflow_similarity/indexer.py [0:0]
def match(self,
          predictions: FloatTensor,
          no_match_label: int = -1,
          k: int = 1,
          matcher: Union[str, ClassificationMatch] = 'match_nearest',
          verbose: int = 1) -> Dict[str, List[int]]:
    """Match embeddings against the various cutpoint thresholds.

    Args:
        predictions: TF similarity model predictions, may be a multi-headed
            output.
        no_match_label: What label value to assign when there is no match.
            Defaults to -1.
        k: How many neighbors to look up for each query during matching.
            Defaults to 1.
        matcher: {'match_nearest', 'match_majority_vote'} or
            ClassificationMatch object. Defines the classification matching,
            e.g., match_nearest will count a True Positive if the query_label
            is equal to the label of the nearest neighbor and the distance is
            less than or equal to the distance threshold.
        verbose: Display progression. Defaults to 1.

    Notes:
        1. It is up to the [`SimilarityModel.match()`](similarity_model.md)
        code to decide which of the cutpoint results to use / show to the
        users. This function returns all of them, as there is little
        performance downside to doing so and it makes the code clearer
        and simpler.
        2. The calling function is responsible for returning the list of
        classes matched, to allow implementations to use additional
        criteria if they choose to.

    Returns:
        Dict mapping each cutpoint name to a list with one entry per
        derived match: the matched label (as int) when its distance is
        less than or equal to the cutpoint's distance threshold, otherwise
        `no_match_label`.
    """
    matcher = make_classification_matcher(matcher)
    lookups = self.batch_lookup(predictions, k=k, verbose=verbose)
    lookup_distances = unpack_lookup_distances(lookups)
    lookup_labels = unpack_lookup_labels(lookups)

    # The derived (label, distance) per query depends only on the lookups,
    # not on the cutpoint, so compute it once rather than once per cutpoint.
    pred_labels, pred_dist = matcher.derive_match(
        lookup_labels=lookup_labels,
        lookup_distances=lookup_distances)

    if verbose:
        pb = tqdm(total=len(lookup_distances) * len(self.cutpoints),
                  desc='matching embeddings')

    matches: DefaultDict[str, List[int]] = defaultdict(list)
    try:
        for cp_name, cp_data in self.cutpoints.items():
            distance_threshold = float(cp_data['distance'])
            for label, distance in zip(pred_labels, pred_dist):
                if distance <= distance_threshold:
                    label = int(label)
                else:
                    label = no_match_label
                matches[cp_name].append(label)
                if verbose:
                    pb.update()
    finally:
        # Always release the progress bar, even if matching fails midway.
        if verbose:
            pb.close()
    return matches