def index()

in tensorflow_recommenders/layers/factorized_top_k.py [0:0]


  def index(
      self,
      candidates: tf.Tensor,
      identifiers: Optional[tf.Tensor] = None) -> "ScaNN":
    """Builds the ScaNN searcher for `candidates` and stores `identifiers`.

    Args:
      candidates: Tensor of candidates to index. Must be rank 2; its first
        dimension must match the first dimension of `identifiers`.
      identifiers: Optional tensor with one row per row of `candidates`.
        When omitted, any identifiers stored by a previous `index` call are
        cleared so they cannot be paired with the new candidates.

    Returns:
      This layer, so calls can be chained.

    Raises:
      ValueError: If `candidates` is not 2D, or if `identifiers` is given and
        its first dimension differs from that of `candidates`.
    """

    if len(candidates.shape) != 2:
      raise ValueError(
          f"The candidates tensor must be 2D (got {candidates.shape}).")

    if identifiers is not None and candidates.shape[0] != identifiers.shape[0]:
      raise ValueError(
          "The candidates and identifiers tensors must have the same number of rows "
          f"(got {candidates.shape[0]} candidates rows and {identifiers.shape[0]} "
          "identifier rows).")

    # Build the searcher over the new candidates and keep only its serialized
    # module form on the layer.
    self._serialized_searcher = self._build_searcher(
        candidates).serialize_to_module()

    if identifiers is None:
      # Clear identifiers left over from an earlier `index` call; otherwise
      # results for the new candidates would be mapped onto stale identifiers.
      # NOTE(review): assumes `None` is the "no identifiers" sentinel used
      # elsewhere in this class (e.g. where `_identifiers` is initialised and
      # read) — confirm.
      self._identifiers = None
    else:
      # The initializer only needs some value with the correct dtype; the
      # real contents are assigned immediately afterwards.
      identifiers_initial_value = tf.zeros((), dtype=identifiers.dtype)
      # NOTE(review): each call reaching this branch adds a new "identifiers"
      # weight to the layer rather than reusing an existing one.
      self._identifiers = self.add_weight(
          name="identifiers",
          dtype=identifiers.dtype,
          shape=identifiers.shape,
          initializer=tf.keras.initializers.Constant(
              value=identifiers_initial_value),
          trainable=False)
      self._identifiers.assign(identifiers)

    # Cached traced functions may have captured the previous searcher or
    # identifiers, so they are reset after re-indexing.
    self._reset_tf_function_cache()

    return self