in tensorflow_recommenders/layers/factorized_top_k.py [0:0]
def __init__(
    self,
    query_model: Optional[tf.keras.Model] = None,
    k: int = 10,
    distance_measure: Text = "dot_product",
    num_leaves: int = 100,
    num_leaves_to_search: int = 10,
    training_iterations: int = 12,
    dimensions_per_block: int = 2,
    num_reordering_candidates: Optional[int] = None,
    parallelize_batch_searches: bool = True,
    name: Optional[Text] = None):
    """Initializes the layer.

    No index is built here: the constructor only records the configuration
    and stores a factory (`self._build_searcher`) that creates a ScaNN
    searcher from a set of candidates when it is invoked.

    Args:
      query_model: Optional Keras model for representing queries. If
        provided, will be used to transform raw features into query
        embeddings when querying the layer. If not provided, the layer will
        expect to be given query embeddings as inputs.
      k: Default number of results to retrieve. Can be overridden in `call`.
        Also passed to the ScaNN builder as `num_neighbors`.
      distance_measure: Distance metric to use; forwarded unchanged to the
        ScaNN builder.
      num_leaves: Number of leaves in the partitioning tree.
      num_leaves_to_search: Number of leaves to search at query time.
      training_iterations: Number of training iterations when performing
        tree building.
      dimensions_per_block: Forwarded to the builder's `score_ah` step.
        Controls the dataset compression ratio: a higher number results in
        greater compression, leading to faster scoring but less accuracy.
      num_reordering_candidates: If set, the index will perform a final
        refinement pass on `num_reordering_candidates` candidates after
        retrieving an initial set of neighbours. This helps improve
        accuracy, but requires the original representations to be kept, and
        so will increase the final model size. If `None`, no reordering
        step is added to the builder.
      parallelize_batch_searches: Whether batch querying should be done in
        parallel.
      name: Name of the layer.

    Raises:
      ImportError: if the scann library is not installed.
    """
    super().__init__(k=k, name=name)

    # ScaNN is an optional dependency; refuse to construct the layer
    # without it.
    if not _HAVE_SCANN:
        raise ImportError(
            "The scann library is not present. Please install it using "
            "`pip install scann` to use the ScaNN layer.")

    # Query-side configuration.
    self.query_model = query_model
    self._k = k
    self._parallelize_batch_searches = parallelize_batch_searches

    # Index-side configuration that the searcher factory reads from `self`
    # at call time (the remaining options are captured from the arguments).
    self._num_reordering_candidates = num_reordering_candidates
    self._training_iterations = training_iterations

    # Nothing has been indexed yet.
    self._identifiers = None

    def build_searcher(candidates):
        """Builds a ScaNN searcher over `candidates` with this configuration."""
        index_builder = (
            scann_ops.builder(
                db=candidates,
                num_neighbors=self._k,
                distance_measure=distance_measure)
            .tree(
                num_leaves=num_leaves,
                num_leaves_to_search=num_leaves_to_search,
                training_iterations=self._training_iterations)
            .score_ah(dimensions_per_block=dimensions_per_block)
        )

        # The refinement pass is opt-in.
        reorder_count = self._num_reordering_candidates
        if reorder_count is not None:
            index_builder = index_builder.reorder(reorder_count)

        # A random suffix gives each searcher its own shared name, which
        # prevents unintentional sharing between ScaNN instances.
        unique_name = f"{self.name}/{uuid.uuid4()}"
        return index_builder.build(shared_name=unique_name)

    # Stored rather than invoked: the searcher is created later, once
    # candidates are available.
    self._build_searcher = build_searcher
    self._serialized_searcher = None