def __init__()

in tensorflow_recommenders/experimental/layers/embedding/partial_tpu_embedding.py [0:0]


  def __init__(self,
               feature_config,
               optimizer: tf.keras.optimizers.Optimizer,
               pipeline_execution_with_tensor_core: bool = False,
               batch_size: Optional[int] = None,
               size_threshold: Optional[int] = 10_000) -> None:
    """Initializes the embedding layer.

    Every feature whose table vocabulary is larger than `size_threshold` is
    collected into one `TPUEmbedding` layer. Every other feature is served by
    a `tf.keras.layers.Embedding` layer, with one Keras layer shared by all
    features that use the same table config.

    Args:
      feature_config: A dict mapping feature names to
        `tf.tpu.experimental.embedding.FeatureConfig` configs. Only a flat
        mapping is supported here, since the configs are read with `.items()`.
      optimizer: An optimizer used for TPU embeddings.
      pipeline_execution_with_tensor_core: If True, the TPU embedding
        computations will overlap with the TensorCore computations (and hence
        will be one step old with potential correctness drawbacks). Set to True
        for improved performance. Only used when at least one table is placed
        on the TPU embedding layer.
      batch_size: If set, this will be used as the global batch size and
        override the autodetection of the batch size from the layer's input.
        This is necessary if all inputs to the layer's call are SparseTensors.
        Only used when at least one table is placed on the TPU embedding layer.
      size_threshold: A threshold for table sizes below which a Keras embedding
        layer is used, and above which a TPU embedding layer is used.
        Set `size_threshold=0` to use TPU embedding for all tables and
        `size_threshold=None` to use only Keras embeddings.
    """
    super().__init__()

    tpu_feature_config = {}
    table_to_keras_emb = {}
    self._keras_embedding_layers = {}

    for name, embedding_feature_config in feature_config.items():
      table_config = embedding_feature_config.table
      if (size_threshold is not None and
          table_config.vocabulary_size > size_threshold):
        # TPUEmbedding layer.
        tpu_feature_config[name] = embedding_feature_config
        continue

      # Keras layer.
      # Multiple features can reuse the same table, so the Keras layer is
      # created once per table config and shared between those features.
      if table_config not in table_to_keras_emb:
        table_to_keras_emb[table_config] = tf.keras.layers.Embedding(
            input_dim=table_config.vocabulary_size,
            output_dim=table_config.dim)
      self._keras_embedding_layers[name] = table_to_keras_emb[table_config]

    if tpu_feature_config:
      # Forward the pipelining and batch size options; without this the
      # documented arguments would be silently ignored.
      self._tpu_embedding = TPUEmbedding(
          tpu_feature_config,
          optimizer,
          pipeline_execution_with_tensor_core=(
              pipeline_execution_with_tensor_core),
          batch_size=batch_size)
    else:
      # No table exceeded the threshold, so no TPU embedding layer is needed.
      self._tpu_embedding = None