def match()

in tensorflow_similarity/indexer.py [0:0]


    def match(self,
              predictions: FloatTensor,
              no_match_label: int = -1,
              k=1,
              matcher: Union[str, ClassificationMatch] = 'match_nearest',
              verbose: int = 1) -> Dict[str, List[int]]:
        """Match embeddings against the various cutpoint thresholds.

        Args:
            predictions: TF similarity model predictions, may be a multi-headed
            output.

            no_match_label: Label value assigned when there is no match, i.e.
            when the derived distance is above the cutpoint's distance
            threshold. Defaults to -1.

            k: How many neighbors to retrieve for each prediction during the
            lookup. Defaults to 1.

            matcher: {'match_nearest', 'match_majority_vote'} or
            ClassificationMatch object. Defines the classification matching,
            e.g., match_nearest will count a True Positive if the query_label
            is equal to the label of the nearest neighbor and the distance is
            less than or equal to the distance threshold.

            verbose: Display progression. Defaults to 1.

        Notes:

            1. It is up to the [`SimilarityModel.match()`](similarity_model.md)
            code to decide which of the cutpoint results to use / show to the
            users. This function returns all of them as there is little
            performance downside to do so and it makes the code clearer
            and simpler.

            2. The calling function is responsible for returning the list of
            classes matched, to allow implementations to use additional
            criteria if they choose to.

        Returns:
            Dict mapping each cutpoint name to a list of matched labels, in
            the order produced by the matcher (presumably one entry per
            prediction). Every cutpoint in `self.cutpoints` gets an entry,
            even when there are no predictions.
        """
        matcher = make_classification_matcher(matcher)

        lookups = self.batch_lookup(predictions, k=k, verbose=verbose)

        lookup_distances = unpack_lookup_distances(lookups)
        lookup_labels = unpack_lookup_labels(lookups)

        # The derived labels / distances only depend on the lookups, not on
        # the cutpoint, so compute them once rather than once per cutpoint.
        pred_labels, pred_dist = matcher.derive_match(
                lookup_labels=lookup_labels,
                lookup_distances=lookup_distances)

        pb = None
        if verbose:
            pb = tqdm(total=len(lookup_distances) * len(self.cutpoints),
                      desc='matching embeddings')

        matches: DefaultDict[str, List[int]] = defaultdict(list)
        try:
            for cp_name, cp_data in self.cutpoints.items():
                distance_threshold = float(cp_data['distance'])

                # Bind the list up front so every cutpoint has an entry even
                # when there are no predictions to match.
                cp_matches = matches[cp_name]

                for label, distance in zip(pred_labels, pred_dist):
                    # Within-threshold matches keep the derived label;
                    # anything farther away is reported as a non-match.
                    cp_matches.append(int(label)
                                      if distance <= distance_threshold
                                      else no_match_label)

                    if pb is not None:
                        pb.update()
        finally:
            # Always release the progress bar, even if matching fails.
            if pb is not None:
                pb.close()

        return matches