def __init__()

in tensorflow_recommenders/layers/factorized_top_k.py [0:0]


  def __init__(self,
               query_model: Optional[tf.keras.Model] = None,
               k: int = 10,
               distance_measure: Text = "dot_product",
               num_leaves: int = 100,
               num_leaves_to_search: int = 10,
               training_iterations: int = 12,
               dimensions_per_block: int = 2,
               num_reordering_candidates: Optional[int] = None,
               parallelize_batch_searches: bool = True,
               name: Optional[Text] = None):
    """Sets up the ScaNN layer's configuration and its searcher factory.

    No searcher is constructed here. Instead, a closure that knows how to
    build one from a set of candidates is stored on the instance.

    Args:
      query_model: Optional Keras model for representing queries. When given,
        it is used to turn raw features into query embeddings at query time.
        When omitted, the layer expects query embeddings as its inputs.
      k: Default number of results to retrieve. Can be overridden in `call`.
      distance_measure: Distance metric to use.
      num_leaves: Number of leaves.
      num_leaves_to_search: Number of leaves to search.
      training_iterations: Number of training iterations when performing tree
        building.
      dimensions_per_block: Controls the dataset compression ratio. A higher
        number means greater compression, which gives faster scoring at the
        cost of accuracy.
      num_reordering_candidates: If set, the index performs a final refinement
        pass over `num_reordering_candidates` candidates after retrieving an
        initial set of neighbours. This helps improve accuracy, but requires
        the original representations to be kept, and so increases the final
        model size.
      parallelize_batch_searches: Whether batch querying should be done in
        parallel.
      name: Name of the layer.

    Raises:
      ImportError: if the scann library is not installed.
    """

    super().__init__(k=k, name=name)

    if not _HAVE_SCANN:
      missing_scann_message = (
          "The scann library is not present. Please install it using "
          "`pip install scann` to use the ScaNN layer.")
      raise ImportError(missing_scann_message)

    self.query_model = query_model

    # Settings that the searcher factory below reads from the instance at
    # build time (rather than capturing the constructor arguments).
    self._k = k
    self._training_iterations = training_iterations
    self._num_reordering_candidates = num_reordering_candidates

    self._parallelize_batch_searches = parallelize_batch_searches

    # Nothing has been indexed at construction time.
    self._identifiers = None

    def build_searcher(candidates):
      """Builds a ScaNN searcher over `candidates` using this layer's config."""
      pipeline = (
          scann_ops.builder(
              db=candidates,
              num_neighbors=self._k,
              distance_measure=distance_measure)
          .tree(
              num_leaves=num_leaves,
              num_leaves_to_search=num_leaves_to_search,
              training_iterations=self._training_iterations)
          .score_ah(dimensions_per_block=dimensions_per_block))

      # The reordering stage is optional and only added when requested.
      reorder_count = self._num_reordering_candidates
      if reorder_count is not None:
        pipeline = pipeline.reorder(reorder_count)

      # A fresh UUID per build keeps the shared name unique, preventing
      # unintentional sharing between ScaNN instances.
      unique_name = f"{self.name}/{uuid.uuid4()}"
      return pipeline.build(shared_name=unique_name)

    self._build_searcher = build_searcher
    self._serialized_searcher = None