def _resource_apply_sparse()

in tensorflow_addons/optimizers/yogi.py


    def _resource_apply_sparse(self, grad, var, indices):
        """Applies sparse gradients to a variable.

        Args:
          grad: A tensor for the `values` of `tf.IndexedSlices`.
          var: A `tf.Variable` object.
          indices: A tensor for the `indices` of `tf.IndexedSlices`.
        Returns:
          An op which updates `var` with `grad` and `indices`.
        """

        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta1_t = self._get_hyper("beta_1", var_dtype)
        beta2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = self._get_hyper("epsilon", var_dtype)
        l1_t = self._get_hyper("l1_regularization_strength", var_dtype)
        l2_t = self._get_hyper("l2_regularization_strength", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta1_power = tf.pow(beta1_t, local_step)
        beta2_power = tf.pow(beta2_t, local_step)

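        # Adam-style bias-corrected step size: dividing by (1 - beta1^t) and
        # multiplying by sqrt(1 - beta2^t) compensates for the zero-initialized
        # moment accumulators during the first few steps.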
        lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)

        update_vs = []
        if self._beta1 == 0.0:
            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
            v = self.get_slot(var, "v")
            grad2 = grad * grad
            v_slice = tf.gather(v, indices)
            if self._activation == "sign":
                sign = tf.sign(grad2 - v_slice)
            elif self._activation == "tanh":
                sign = tf.tanh(10 * (grad2 - v_slice))
            else:
                raise NotImplementedError("Activation function must be 'sign' or 'tanh'.")
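            # tanh(10 * x) is a smooth surrogate for sign(x); the fixed factor of 10
            # keeps the transition around zero sharp while remaining differentiable.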
            v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
            v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
            v_sqrt = tf.sqrt(v_scaled_g_values)

            # Yogi effective LR
            per_coord_lr = lr / (v_sqrt + epsilon_t)

            # Variable update
            # Step 1: Gradient descent
            var_slice = tf.gather(var, indices)
            new_var = var_slice - per_coord_lr * grad
            # Step 2: Prox operator
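            # With L1 regularization, the closed-form proximal step (via _solve)
            # soft-thresholds each coordinate; with only L2 it reduces to the
            # shrinkage by 1 / (1 + l2 * per_coord_lr) below.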
            if self._l1_regularization_strength > 0:
                new_var = _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
            elif self._l2_regularization_strength > 0:
                new_var = new_var / (1 + l2_t * per_coord_lr)
            # Step 3: Update
            var_update = self._resource_scatter_update(var, indices, new_var)
            update_vs.append(var_update)
            update_vs.append(v_t)

        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
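            # Sparse-friendly two-step update: decay the entire momentum slot by
            # beta1 first, then scatter the (1 - beta1) * g_t contribution into
            # only the rows selected by `indices`.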
            m = self.get_slot(var, "m")
            m_scaled_g_values = grad * (1 - beta1_t)
            m_t = m.assign(m * beta1_t, use_locking=self._use_locking)
            with tf.control_dependencies([m_t]):
                m_slice = tf.gather(m, indices) + m_scaled_g_values
                m_t = self._resource_scatter_update(m, indices, m_slice)

            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
            v = self.get_slot(var, "v")
            grad2 = grad * grad
            v_slice = tf.gather(v, indices)
            if self._activation == "sign":
                sign = tf.sign(grad2 - v_slice)
            elif self._activation == "tanh":
                sign = tf.tanh(10 * (grad2 - v_slice))
            else:
                raise NotImplementedError("Activation function must be 'sign' or 'tanh'.")
            v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
            v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
            v_sqrt = tf.sqrt(v_scaled_g_values)

            # Yogi effective LR
            per_coord_lr = lr / (v_sqrt + epsilon_t)

            # Variable update
            # Step 1: Gradient descent
            var_slice = tf.gather(var, indices)
            new_var = var_slice - per_coord_lr * m_slice
            # Step 2: Prox operator
            if self._l1_regularization_strength > 0:
                new_var = _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
            elif self._l2_regularization_strength > 0:
                new_var = new_var / (1 + l2_t * per_coord_lr)
            # Step 3: Update
            var_update = self._resource_scatter_update(var, indices, new_var)
            update_vs.append(var_update)
            update_vs.append(m_t)
            update_vs.append(v_t)

        # Create an op that groups all the above operations
        return tf.group(*update_vs)
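
A minimal usage sketch (not part of yogi.py, and assuming the optimizer is exported
as `tfa.optimizers.Yogi`): an `Embedding` layer yields `tf.IndexedSlices` gradients,
so `apply_gradients` routes its variable through the sparse path above, while the
dense kernel and bias gradients take the dense path.

import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=16),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(1),
])
optimizer = tfa.optimizers.Yogi(learning_rate=0.01)

x = tf.constant([[1, 2, 3], [4, 5, 6]])
y = tf.constant([[0.0], [1.0]])

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))

grads = tape.gradient(loss, model.trainable_variables)
# The embedding gradient is a tf.IndexedSlices, so it is applied through
# _resource_apply_sparse; the Dense layer's gradients use the dense update.
optimizer.apply_gradients(zip(grads, model.trainable_variables))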