def _resource_apply_dense()

in tensorflow_addons/optimizers/yogi.py
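
In normal use this method is not called directly; `tf.keras`-style optimizers reach it from `apply_gradients()` for each dense (gradient, variable) pair. A minimal, self-contained sketch of that path (assuming the public `tfa.optimizers.Yogi` constructor with its default hyperparameters):

    import tensorflow as tf
    import tensorflow_addons as tfa

    opt = tfa.optimizers.Yogi(learning_rate=1e-2)
    var = tf.Variable([1.0, 2.0])

    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(tf.square(var))
    grads = tape.gradient(loss, [var])

    # apply_gradients() routes each dense (grad, var) pair through
    # _resource_apply_dense() shown below.
    opt.apply_gradients(zip(grads, [var]))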


    def _resource_apply_dense(self, grad, var):
        """See `tf.train.Optimizer._apply_dense()`."""
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta1_t = self._get_hyper("beta_1", var_dtype)
        beta2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = self._get_hyper("epsilon", var_dtype)
        l1_t = self._get_hyper("l1_regularization_strength", var_dtype)
        l2_t = self._get_hyper("l2_regularization_strength", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta1_power = tf.pow(beta1_t, local_step)
        beta2_power = tf.pow(beta2_t, local_step)

        lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)
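        # The sqrt(1 - beta2^t) / (1 - beta1^t) factor is the standard Adam-style
        # bias correction, folded here into a single scalar step size rather than
        # applied to m_t and v_t separately.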

        update_vs = []
        if self._beta1 == 0.0:
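            # beta_1 == 0 disables the first-moment (momentum) accumulator;
            # only the second-moment slot "v" is read and updated on this path.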
            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
            v = self.get_slot(var, "v")
            grad2 = grad * grad
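            # "tanh" is a smooth surrogate for sign(): tanh(10 * x) is close to
            # sign(x) away from zero, avoiding the hard discontinuity at 0.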
            if self._activation == "sign":
                sign = tf.sign(grad2 - v)
            elif self._activation == "tanh":
                sign = tf.tanh(10 * (grad2 - v))
            else:
                raise NotImplementedError("Activation function must be 'sign' or 'tanh'")
            v_t = v.assign_add(
                (1 - beta2_t) * sign * grad2, use_locking=self._use_locking
            )
            v_sqrt = tf.sqrt(v_t)

            # Yogi effective LR
            per_coord_lr = lr / (v_sqrt + epsilon_t)

            # Variable update
            # Step 1: Gradient descent
            new_var = var - per_coord_lr * grad
            # Step 2: Prox operator
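            # With l1 > 0, the combined l1/l2 proximal step has a closed form:
            # _solve (defined elsewhere in yogi.py) is assumed to return
            #   argmin_w 0.5 * a * w**2 + b * w + c * |w|
            # i.e. a soft-thresholded, l2-shrunk version of new_var.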
            if self._l1_regularization_strength > 0:
                new_var = _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
            elif self._l2_regularization_strength > 0:
                new_var = new_var / (1 + l2_t * per_coord_lr)
            # Step 3: Update
            var_update = var.assign(new_var, use_locking=self._use_locking)

            update_vs.append(var_update)
            update_vs.append(v_t)

        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
            m = self.get_slot(var, "m")
            m_t = m.assign(
                m * beta1_t + grad * (1 - beta1_t), use_locking=self._use_locking
            )

            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
            v = self.get_slot(var, "v")
            grad2 = grad * grad
            if self._activation == "sign":
                sign = tf.sign(grad2 - v)
            elif self._activation == "tanh":
                sign = tf.tanh(10 * (grad2 - v))
            else:
                raise NotImplementedError("Activation function must be 'sign' or 'tanh'")
            v_t = v.assign_add(
                (1 - beta2_t) * sign * grad2, use_locking=self._use_locking
            )
            v_sqrt = tf.sqrt(v_t)

            # Yogi effective LR
            per_coord_lr = lr / (v_sqrt + epsilon_t)

            # Variable update
            # Step 1: Gradient descent
            new_var = var - per_coord_lr * m_t
            # Step 2: Prox operator
            if self._l1_regularization_strength > 0:
                new_var = _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
            elif self._l2_regularization_strength > 0:
                new_var = new_var / (1 + l2_t * per_coord_lr)
            # Step 3: Update
            var_update = var.assign(new_var, use_locking=self._use_locking)
            update_vs.append(var_update)
            update_vs.append(m_t)
            update_vs.append(v_t)

        # Create an op that groups all the above operations
        return tf.group(*update_vs)
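
The `_solve` helper used in the prox step above is defined elsewhere in yogi.py and is not shown in this excerpt. As a sketch of the closed form it is expected to compute, consistent with how it is called here:

    def _solve(a, b, c):
        """Return argmin_w 0.5 * a * w**2 + b * w + c * |w|, elementwise, with c >= 0."""
        # Closed-form minimizer once w is known to be nonzero; its sign is -sign(b).
        w = (c * tf.sign(b) - b) / a
        # Soft-threshold: the minimizer is exactly 0 wherever |b| <= c.
        return tf.cast(tf.abs(b) > c, dtype=b.dtype) * w

With a = 1 + l2_t * per_coord_lr, b = -new_var, and c = l1_t * per_coord_lr, this would reduce to the familiar elastic-net (l1/l2) proximal update of the plain descent step.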