in tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py [0:0]
def translate_keras_optimizer(optimizer):
    """Builds the tf.tpu.experimental.embedding optimizer for a Keras optimizer.

    Keras optimizer hyperparameters may be Tensors or callables, while the
    tpu_embedding optimizer parameters require floats. The hyperparameters are
    read through the Keras optimizer's get_config(), which evaluates each of
    them immediately; the learning rate is the one exception and is re-wrapped
    in a callable when it is dynamic (see the comments in the body).

    NOTE: the Keras optimizer passed in is also the one used to create the slot
    variables of the embedding tables that the returned optimizer is applied
    to.

    Args:
        optimizer: A Keras optimizer parameter object.

    Returns:
        The tpu_embedding parameter object corresponding to optimizer.

    Raises:
        ValueError: If optimizer is a Keras optimizer without a tpu_embedding
            counterpart, is not a Keras optimizer at all, or (per the check
            delegated to _ensure_unsupported_params_unchanged) sets parameters
            that the tpu_embedding counterpart does not support.
    """
    optimizer_cls = optimizer.__class__

    # Reject anything without a registered translation up front. The lookup
    # is keyed on the exact class of the optimizer.
    if optimizer_cls not in _OPTIMIZER_PARAMETERS:
        if isinstance(optimizer, tf.keras.optimizers.Optimizer):
            raise ValueError(
                "Keras optimizer %s is unsupported for TPU Embedding."
                % optimizer_cls.__name__)
        raise ValueError(
            "%s is an unsupported optimizer class. Please pass a "
            "Keras optimizer." % optimizer_cls.__name__)

    embedding_optimizer, supported, unsupported = (
        _OPTIMIZER_PARAMETERS[optimizer_cls])
    config = optimizer.get_config()

    # The learning rate needs special treatment so that dynamic learning
    # rates keep working. What get_config() hands back depends on what the
    # user originally supplied:
    #   * A function: get_config() has already called it, which is not what
    #     we want, since the call should happen in the strategy context so
    #     that its ops land on the TPU. The value is then generally a tensor,
    #     so we defer by wrapping a fresh get_config() lookup in a callable.
    #   * A LearningRateSchedule: get_config() serialized it into a dict, so
    #     we rebuild the schedule and evaluate it lazily at the optimizer's
    #     current iteration.
    #   * A Python numeric constant (or something convertible to one): used
    #     unchanged.
    learning_rate = config["learning_rate"]
    if isinstance(learning_rate, tf.Tensor):
        config["learning_rate"] = (
            lambda: optimizer.get_config()["learning_rate"])
    elif isinstance(learning_rate, dict):
        lr_schedule = tf.keras.optimizers.schedules.deserialize(learning_rate)
        config["learning_rate"] = lambda: lr_schedule(optimizer.iterations)

    # Verify that only supported parameters have been set on the optimizer.
    _ensure_unsupported_params_unchanged(optimizer, supported, unsupported)

    # Forward only the hyperparameters the tpu_embedding optimizer accepts.
    embedding_kwargs = {name: config[name] for name in supported}

    # Optimizers that have slots also need the slot variable creation fn.
    if optimizer_cls in _SLOT_NAME_MAPPING:
        embedding_kwargs["slot_variable_creation_fn"] = (
            _get_slot_variable_creation_fn(optimizer))

    return embedding_optimizer(**embedding_kwargs)