def build()

in tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py [0:0]


  def build(self, input_shape):
    """Builds the underlying TPU embedding and registers its variables.

    Records the active distribution strategy, optionally infers the batch size
    from `input_shape`, builds `self._tpu_embedding`, and then makes the layer
    track the resulting variables. On TPU, a dummy trainable weight is created
    and stored under the attribute named by `_DUMMY_NAME`. Off TPU, the
    embedding tables are appended to this layer's trainable weights.

    Args:
      input_shape: Shape(s) of the layer inputs, as passed by Keras. Used here
        only to infer the batch size when running under a TPU strategy and no
        batch size was configured.
    """
    super().build(input_shape)

    self._strategy = tf.distribute.get_strategy()
    self._using_tpu = _is_tpu_strategy(self._strategy)

    # The batch size is only inferred on TPU; off TPU it may still be None
    # when handed to the embedding's build call below.
    if self._batch_size is None and self._using_tpu:
      self._batch_size = _get_batch_size_from_input_shapes(input_shape)

    self._tpu_embedding.build(self._batch_size)

    if self._using_tpu:
      # The Python attribute name must match the weight name (_DUMMY_NAME),
      # otherwise the variable will appear twice in the list of saveables.
      # The attribute name is also used to name internal objects, so it is
      # assigned via setattr to guarantee the two names stay identical.
      setattr(
          self, _DUMMY_NAME,
          self.add_weight(
              name=_DUMMY_NAME,
              shape=(1,),
              initializer=tf.zeros_initializer(),
              trainable=True,
              dtype=tf.float32))
    else:
      # When not on TPU, ensure the embedding tables are part of the trainable
      # variables list for this layer. Only the table variables (the mapping's
      # values) are needed; their keys are irrelevant here.
      self._trainable_weights.extend(
          self._tpu_embedding.embedding_tables.values())