in tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py [0:0]
def build(self, input_shape):
  """Builds the layer for whichever distribution strategy is active.

  Records the current strategy and whether it is a TPU strategy, builds the
  wrapped embedding object, and then either registers a dummy weight (TPU) or
  adds the embedding tables to this layer's trainable weights (non-TPU).

  Args:
    input_shape: Shape(s) of the layer input. Forwarded to the parent
      `build`, and used to derive the batch size when running under a TPU
      strategy and no batch size has been set on the layer.
  """
  super().build(input_shape)

  active_strategy = tf.distribute.get_strategy()
  self._strategy = active_strategy
  self._using_tpu = _is_tpu_strategy(active_strategy)

  # The batch size is only derived from the input shapes on TPU, and only
  # when it was not already set on the layer.
  must_infer_batch_size = self._batch_size is None and self._using_tpu
  if must_infer_batch_size:
    self._batch_size = _get_batch_size_from_input_shapes(input_shape)
  self._tpu_embedding.build(self._batch_size)

  if not self._using_tpu:
    # When on CPU, the embedding tables have to be listed among this layer's
    # trainable variables.
    for table_variable in self._tpu_embedding.embedding_tables.values():
      self._trainable_weights.append(table_variable)
    return

  # The weight name must match _DUMMY_NAME, or it will appear twice in the
  # list of saveables. The Python attribute name must be _DUMMY_NAME as well,
  # since it is used to name internal objects; assigning through setattr
  # enforces that.
  dummy_weight = self.add_weight(
      name=_DUMMY_NAME,
      shape=(1,),
      initializer=tf.zeros_initializer(),
      trainable=True,
      dtype=tf.float32)
  setattr(self, _DUMMY_NAME, dummy_weight)