def DynamicEmbeddingOptimizer(self, bp_v2=None)

in tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py


def DynamicEmbeddingOptimizer(self, bp_v2=None):
  """ An optimizer wrapper to make any TensorFlow optimizer capable of training
  Dynamic Embeddding Variables.

  Args:
    self: a TensorFlow optimizer.
    bp_v2: Defaults to None. If None, use the `params_var_.bp_v2` setting
        (see `tfra.dynamic_embedding_variable.get_variable`).
  Example usage:

    ```python
    optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
        tf.train.AdamOptimizer(0.001))
    ```

  Returns:
    The same optimizer, now able to train Dynamic Embedding Variables.
  """
  self._bp_v2 = bp_v2

  def _distributed_apply(distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`."""

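    # Two paths below: plain variables go through the stock
    # `_resource_apply_*` ops, while `de.TrainableWrapper` variables also
    # read their values (and optimizer slots) first and write the results
    # back into the dynamic embedding tables afterwards.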
    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if not isinstance(var, de.TrainableWrapper):
        if isinstance(grad, ops.IndexedSlices):
          if var.constraint is not None:
            raise RuntimeError(
                "Cannot use a constraint function on a sparse variable.")
          if "apply_state" in self._sparse_apply_args:
            apply_kwargs["apply_state"] = apply_state
          return self._resource_apply_sparse_duplicate_indices(
              grad.values, var, grad.indices, **apply_kwargs)

        if "apply_state" in self._dense_apply_args:
          apply_kwargs["apply_state"] = apply_state
        update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
        if var.constraint is not None:
          with ops.control_dependencies([update_op]):
            return var.assign(var.constraint(var))
        else:
          return update_op
      else:
        with ops.colocate_with(None, ignore_existing=True):
          _slots = [self.get_slot(var, _s) for _s in self.get_slot_names()]
          # Add the optimizer slots to restricting list.
          var._track_optimizer_slots(_slots)

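          # Read the current embedding and slot values before the update so
          # that `update_op(v0=...)` can write the post-update values back
          # into the hash tables (with bp_v2, as deltas relative to `v0`).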
          with ops.control_dependencies([grad]):
            v0 = var.read_value(do_prefetch=not var.params.bp_v2)
            s0 = [_s.read_value() for _s in _slots]
            _before = [v0] + s0

          if isinstance(grad, ops.IndexedSlices):
            if var.constraint is not None:
              raise RuntimeError(
                  "Cannot use a constraint function on a sparse variable.")
            if "apply_state" in self._sparse_apply_args:
              apply_kwargs["apply_state"] = apply_state
            with ops.control_dependencies(_before):
              _apply_op = self._resource_apply_sparse_duplicate_indices(
                  grad.values, var, grad.indices, **apply_kwargs)
            with ops.control_dependencies([_apply_op]):
              _after = control_flow_ops.group(
                  [var.update_op(v0=v0)] +
                  [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)])
              return _after

          if "apply_state" in self._dense_apply_args:
            apply_kwargs["apply_state"] = apply_state
          with ops.control_dependencies(_before):
            update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
          if var.constraint is not None:
            with ops.control_dependencies([update_op]):
              return var.assign(var.constraint(var))
          else:
            with ops.control_dependencies([update_op]):
              _after = control_flow_ops.group(
                  [var.update_op(v0=v0)] +
                  [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)])
            return _after

    update_ops = []
    with backend.name_scope(name or self._name):
      for grad, var in grads_and_vars:
        scope_name = ("update" if ops.executing_eagerly_outside_functions() else
                      "update_" + var.op.name)
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with backend.name_scope(
            scope_name), distribution.extended.colocate_vars_with(var):
          update_ops.extend(
              distribution.extended.update(var,
                                           apply_grad_to_update_var,
                                           args=(grad,),
                                           group=False))

      any_symbolic = any(
          isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i)
          for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

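  # OptimizerV2/Keras slot creation: for `de.TrainableWrapper` variables the
  # slot is itself a dynamic embedding table built via `de.create_slots`, so
  # slot values are stored per key alongside the primary variable.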
  def add_slot(var, slot_name, initializer="zeros", shape=None):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = optimizer_v2._var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        if isinstance(
            initializer,
            trackable.CheckpointInitialValueCallable) or (shape is not None):
          slot_shape = shape
        else:
          slot_shape = var.shape
        initial_value = functools.partial(initializer,
                                          shape=slot_shape,
                                          dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      with strategy.extended.colocate_vars_with(var):
        if isinstance(var, de.TrainableWrapper):
          weight = de.create_slots(var, initial_value, slot_name,
                                   var._shared_name, self._bp_v2)
        else:
          weight = variables.Variable(
              name="%s/%s" % (
                  var._shared_name,
                  slot_name,
              ),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value,
          )
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(slot_name=slot_name,
                                  variable=var,
                                  slot_variable=weight)
      self._weights.append(weight)
    return weight

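  # The three helpers below mirror the same idea for V1 `tf.train` optimizers:
  # slots for `de.TrainableWrapper` variables are created with
  # `de.create_slots`, everything else falls back to `slot_creator`.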
  def _get_or_make_slot(var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if optimizer._var_key(var) not in named_slots:
      if isinstance(var, de.TrainableWrapper):
        new_slot_variable = de.create_slots(var, val, slot_name, op_name,
                                            self._bp_v2)
      else:
        new_slot_variable = slot_creator.create_slot(var, val, op_name)
      self._restore_slot_variable(slot_name=slot_name,
                                  variable=var,
                                  slot_variable=new_slot_variable)
      named_slots[optimizer._var_key(var)] = new_slot_variable
    return named_slots[optimizer._var_key(var)]

  def _get_or_make_slot_with_initializer(var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if optimizer._var_key(var) not in named_slots:
      if isinstance(var, de.TrainableWrapper):
        new_slot_variable = de.create_slots(var, initializer, slot_name,
                                            op_name, self._bp_v2)
      else:
        new_slot_variable = slot_creator.create_slot_with_initializer(
            var, initializer, shape, dtype, op_name)
      self._restore_slot_variable(slot_name=slot_name,
                                  variable=var,
                                  slot_variable=new_slot_variable)
      named_slots[optimizer._var_key(var)] = new_slot_variable
    return named_slots[optimizer._var_key(var)]

  def _zeros_slot(var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if optimizer._var_key(var) not in named_slots:
      if isinstance(var, de.TrainableWrapper):
        new_slot_variable = de.create_slots(var, 0.0, slot_name, op_name,
                                            self._bp_v2)
      else:
        new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(slot_name=slot_name,
                                  variable=var,
                                  slot_variable=new_slot_variable)
      named_slots[optimizer._var_key(var)] = new_slot_variable
    return named_slots[optimizer._var_key(var)]

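  # Patch the slot-creation hooks on V1 optimizers, or `add_slot` and
  # `_distributed_apply` on OptimizerV2/Keras optimizers.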
  if isinstance(self, optimizer.Optimizer):
    self._get_or_make_slot = _get_or_make_slot
    self._get_or_make_slot_with_initializer = _get_or_make_slot_with_initializer
    self._zeros_slot = _zeros_slot
  elif isinstance(self, optimizer_v2.OptimizerV2) or isinstance(
      self, keras_optimizer):
    self.add_slot = add_slot
    self._distributed_apply = _distributed_apply
  else:
    raise Exception("Optimizer type is not supported! got {}".format(
        str(type(self))))
  return self
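
A minimal end-to-end sketch of how the wrapped optimizer is typically used in
eager mode. The `get_variable`/`embedding_lookup` calls, the
`return_trainable=True` flag, and the chosen parameters are assumptions based
on the public `tfra.dynamic_embedding` API, not part of this source file:

```python
import tensorflow as tf
import tensorflow_recommenders_addons as tfra

# Wrap a stock Keras optimizer so it can update dynamic embedding tables.
optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
    tf.keras.optimizers.Adam(1e-3))

# A dynamic embedding variable keyed by int64 ids (assumed API and parameters).
embeddings = tfra.dynamic_embedding.get_variable(
    name="user_embeddings", dim=16, initializer=0.0)

ids = tf.constant([3, 7, 7, 42], dtype=tf.int64)

with tf.GradientTape() as tape:
  # `return_trainable=True` also yields the `de.TrainableWrapper` that the
  # wrapped optimizer knows how to update.
  emb, trainable = tfra.dynamic_embedding.embedding_lookup(
      embeddings, ids, return_trainable=True)
  loss = tf.reduce_mean(tf.square(emb))

grads = tape.gradient(loss, [trainable])
optimizer.apply_gradients(zip(grads, [trainable]))
```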