in tensorflow_recommenders/layers/factorized_top_k.py [0:0]
def call(self,
         queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
         k: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
  """Looks up the top-k neighbours of `queries` in the serialized ScaNN index.

  Args:
    queries: Either a query tensor, or inputs that are first passed through
      `self.query_model` when that attribute is not None. After that step
      the value must be a `tf.Tensor` of rank 1 (searched with
      `searcher.search`) or rank 2 (searched with a batched search).
    k: Number of neighbours requested from the searcher. When None,
      `self._k` is used instead.

  Returns:
    A `(distances, ids)` pair. `ids` holds the raw searcher indices when
    `self._identifiers` is None; otherwise it holds the entries of
    `self._identifiers` gathered at those indices.

  Raises:
    ValueError: If no serialized searcher exists yet (the `index` method has
      not been called), if the queries are not a `tf.Tensor` after the
      optional query model, or if their rank is neither 1 nor 2.
  """
  if k is None:
    k = self._k

  serialized = self._serialized_searcher
  if serialized is None:
    raise ValueError("The `index` method must be called first to "
                     "create the retrieval index.")
  # The searcher is rebuilt from the serialized module on every call; this
  # happens before the query model runs, matching the original op order.
  searcher = scann_ops.searcher_from_module(serialized)

  query_model = self.query_model
  if query_model is not None:
    queries = query_model(queries)

  if not isinstance(queries, tf.Tensor):
    raise ValueError(f"Queries must be a tensor, got {type(queries)}.")

  rank = len(queries.shape)
  if rank == 1:
    # The single-query result exposes singular field names.
    single = searcher.search(queries, final_num_neighbors=k)
    indices, distances = single.index, single.distance
  elif rank == 2:
    # The batched result exposes plural field names; the parallel variant
    # is selected by the layer's `_parallelize_batch_searches` flag.
    if self._parallelize_batch_searches:
      search_fn = searcher.search_batched_parallel
    else:
      search_fn = searcher.search_batched
    batch = search_fn(queries, final_num_neighbors=k)
    indices, distances = batch.indices, batch.distances
  else:
    raise ValueError(
        f"Queries must be of rank 2 or 1, got {rank}.")

  identifiers = self._identifiers
  if identifiers is None:
    return distances, indices
  # Map positional searcher indices back to the stored candidate identifiers.
  return distances, tf.gather(identifiers, indices)