in tensorflow_recommenders/layers/factorized_top_k.py [0:0]
def index(
    self,
    candidates: tf.Tensor,
    identifiers: Optional[tf.Tensor] = None) -> "ScaNN":
  """Builds the ScaNN search index over `candidates`.

  Args:
    candidates: A 2D tensor holding one candidate per row.
    identifiers: Optional tensor of identifiers with the same number of rows
      as `candidates`. When omitted, any identifiers stored by a previous
      call to `index` are discarded.

  Returns:
    The layer itself, so calls can be chained.

  Raises:
    ValueError: If `candidates` is not 2D, or if `identifiers` is given and
      its number of rows differs from that of `candidates`.
  """
  if len(candidates.shape) != 2:
    raise ValueError(
        f"The candidates tensor must be 2D (got {candidates.shape}).")
  if identifiers is not None and candidates.shape[0] != identifiers.shape[0]:
    raise ValueError(
        "The candidates and identifiers tensors must have the same number of rows "
        f"(got {candidates.shape[0]} candidates rows and {identifiers.shape[0]} "
        "identifier rows). "
    )

  # Validation happens before any state is touched, so a bad call leaves the
  # previous index intact. `_build_searcher` is defined elsewhere in the class;
  # only its serialized form is kept on the layer.
  self._serialized_searcher = self._build_searcher(
      candidates).serialize_to_module()

  if identifiers is None:
    # Discard identifiers from an earlier call: they describe the previous
    # candidate set, not the one just indexed.
    # NOTE(review): assumes the rest of the class treats
    # `_identifiers is None` as "no identifiers" - confirm against
    # __init__/call.
    self._identifiers = None
  else:
    # We need any value that has the correct dtype; it is overwritten by
    # the `assign` below.
    identifiers_initial_value = tf.zeros((), dtype=identifiers.dtype)
    self._identifiers = self.add_weight(
        name="identifiers",
        dtype=identifiers.dtype,
        shape=identifiers.shape,
        initializer=tf.keras.initializers.Constant(
            value=identifiers_initial_value),
        trainable=False)
    self._identifiers.assign(identifiers)

  # Presumably invalidates cached/traced tf.functions so later calls see the
  # new searcher and identifiers - confirm against its definition.
  self._reset_tf_function_cache()
  return self