def call()

in tensorflow_recommenders/layers/factorized_top_k.py [0:0]


  def call(self,
           queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
           k: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
    """Queries the ScaNN searcher for the `k` nearest neighbours of `queries`.

    Args:
      queries: Query features. When `self.query_model` is set they are passed
        through it first; whatever reaches the searcher must be a `tf.Tensor`
        of rank 1 (single search) or rank 2 (batched search).
      k: Number of neighbours requested from the searcher. Falls back to
        `self._k` when None.

    Returns:
      A `(distances, indices)` tuple as reported by the searcher. When
      `self._identifiers` is set, the second element is instead
      `tf.gather(self._identifiers, indices)`.

    Raises:
      ValueError: If no serialized searcher is available, if the (possibly
        transformed) queries are not a `tf.Tensor`, or if their rank is
        neither 1 nor 2.
    """
    if k is None:
      k = self._k

    if self._serialized_searcher is None:
      raise ValueError("The `index` method must be called first to "
                       "create the retrieval index.")

    # The searcher is rebuilt from its serialized form on every call.
    scann_searcher = scann_ops.searcher_from_module(self._serialized_searcher)

    if self.query_model is not None:
      queries = self.query_model(queries)

    if not isinstance(queries, tf.Tensor):
      raise ValueError(f"Queries must be a tensor, got {type(queries)}.")

    # Reject unsupported ranks up front so the dispatch below only has to
    # distinguish the single-query case from the batched one.
    rank = len(queries.shape)
    if rank not in (1, 2):
      raise ValueError(f"Queries must be of rank 2 or 1, got {rank}.")

    if rank == 1:
      hit = scann_searcher.search(queries, final_num_neighbors=k)
      neighbor_ids, neighbor_distances = hit.index, hit.distance
    else:
      if self._parallelize_batch_searches:
        batched_search = scann_searcher.search_batched_parallel
      else:
        batched_search = scann_searcher.search_batched
      hits = batched_search(queries, final_num_neighbors=k)
      neighbor_ids, neighbor_distances = hits.indices, hits.distances

    # Translate raw index positions into stored identifiers when available.
    if self._identifiers is not None:
      neighbor_ids = tf.gather(self._identifiers, neighbor_ids)

    return neighbor_distances, neighbor_ids