def translate_keras_optimizer()

in tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py


def translate_keras_optimizer(optimizer):
  """Translates a Keras optimizer to the tf.tpu.experimental.embedding version.

  Note that Keras optimizer params can accept Tensors or callables, whereas
  tpu_embedding optimizer params require floats. We call .get_config() on the
  Keras optimizer, which evaluates each param immediately.

  NOTE: The underlying Keras optimizer passed in will be used to create the
  slot variables for the embedding tables this optimizer is used for.

  Args:
    optimizer: A Keras optimizer instance.

  Raises:
    ValueError: if passed a Keras optimizer defining parameters unsupported by
        the corresponding tpu_embedding object, or an unsupported Keras
        optimizer.

  Returns:
    The tpu_embedding optimizer parameter object corresponding to `optimizer`.
  """

  if optimizer.__class__ in _OPTIMIZER_PARAMETERS:
    embedding_optimizer, supported, unsupported = (
        _OPTIMIZER_PARAMETERS[optimizer.__class__])
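    # Illustrative note: based on how it is unpacked above, _OPTIMIZER_PARAMETERS
    # is assumed to map a Keras optimizer class to a tuple of
    # (tpu_embedding optimizer class, supported config keys, unsupported config
    # keys).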
    config = optimizer.get_config()
    # We need to handle learning_rate specially so that we can properly support
    # dynamic learning rate. Depending on what the user passed for learning_rate
    # get_config does a few different things:
    # 1.  If it was a function, it calls the function (which we do not want, as
    #     we want to call the function in the strategy context so that all
    #     ops in the function are placed on the TPU). In this case the return
    #     type should generally be a tensor.
    # 2.  If it was a LearningRateSchedule, get_config calls
    #     serialize_keras_object on the schedule object. In this case the return
    #     type is a dict.
    # 3.  If it was a python numeric constant (or something convertible to
    #     one), it is returned as-is and can be passed straight through.
    if isinstance(config["learning_rate"], tf.Tensor):
      config["learning_rate"] = lambda: optimizer.get_config()["learning_rate"]
    elif isinstance(config["learning_rate"], dict):
      schedule = tf.keras.optimizers.schedules.deserialize(
          config["learning_rate"])
      config["learning_rate"] = lambda: schedule(optimizer.iterations)

    # Check to make sure only supported params are set.
    _ensure_unsupported_params_unchanged(optimizer, supported, unsupported)

    params = {k: config[k] for k in supported}
    # If the optimizer has slots, add the slot variable creation fn.
    if optimizer.__class__ in _SLOT_NAME_MAPPING:
      params["slot_variable_creation_fn"] = _get_slot_variable_creation_fn(
          optimizer)
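      # Illustrative note: for a stateful optimizer such as Adam, the creation
      # fn lets the Keras optimizer itself build and track the slot variables
      # (e.g. its "m" and "v" slots) for the embedding tables, per the
      # docstring above; _SLOT_NAME_MAPPING is assumed to provide the mapping
      # to the corresponding tpu_embedding slot names.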

    return embedding_optimizer(**params)

  elif isinstance(optimizer, tf.keras.optimizers.Optimizer):
    raise ValueError("Keras optimizer %s is unsupported for TPU Embedding."
                     % optimizer.__class__.__name__)
  else:
    raise ValueError("%s is an unsupported optimizer class. Please pass a "
                     "Keras optimizer." % optimizer.__class__.__name__)