def batch_lookup()

in tensorflow_similarity/indexer.py [0:0]


    def batch_lookup(self,
                     predictions: FloatTensor,
                     k: int = 5,
                     verbose: int = 1) -> List[List[Lookup]]:
        """Find the k closest matches for a set of embeddings.

        Args:
            predictions: TF similarity model predictions, may be a
              multi-headed output. Converted to query embeddings via
              `self._get_embeddings`.

            k: Number of nearest neighbors to lookup. Defaults to 5.

            verbose: Be verbose (print a status message and display a
              progress bar while building results). Defaults to 1.

        Returns:
            A list with one entry per query embedding. Each entry is the list
            of that query's neighbors as `Lookup` objects, ranked from 1 in
            the order returned by the search index. Returns an empty list
            when `predictions` yields no embeddings.
        """
        embeddings = self._get_embeddings(predictions)
        num_embeddings = len(embeddings)

        # Nothing to search for. Returning early also avoids dividing by zero
        # when computing the per-lookup timing stats below.
        if num_embeddings == 0:
            return []

        start = time()
        batch_lookups = []

        if verbose:
            print("\nPerforming NN search\n")
        batch_idxs, batch_distances = self.search.batch_lookup(embeddings,
                                                               k=k)

        pb = (tqdm(total=num_embeddings, desc='Building NN list')
              if verbose else None)
        try:
            # One row of neighbor indices and distances per query embedding.
            for lidxs, distances in zip(batch_idxs, batch_distances):
                nn_embeddings, labels, data = self.kv_store.batch_get(lidxs)
                neighbors = zip(nn_embeddings, distances, labels, data)
                # ! casting is needed to avoid slowness down the line
                lookups = [
                    Lookup(rank=rank,
                           embedding=nn_embedding,
                           distance=float(distance),
                           label=self._cast_label(label),
                           data=datum)
                    for rank, (nn_embedding, distance, label, datum)
                    in enumerate(neighbors, start=1)
                ]
                batch_lookups.append(lookups)

                if pb is not None:
                    pb.update()
        finally:
            # Close the progress bar even if building the results fails.
            if pb is not None:
                pb.close()

        # stats: the batch's wall time is spread evenly over its lookups and
        # recorded once per query embedding.
        lookup_time = time() - start
        per_lookup_time = lookup_time / num_embeddings
        self._lookup_timings_buffer.extend([per_lookup_time] * num_embeddings)
        self._stats['num_lookups'] += num_embeddings

        return batch_lookups