in tensorflow_addons/optimizers/yogi.py
    def _resource_apply_dense(self, grad, var):
        """See `tf.keras.optimizers.Optimizer._resource_apply_dense()`."""
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta1_t = self._get_hyper("beta_1", var_dtype)
        beta2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = self._get_hyper("epsilon", var_dtype)
        l1_t = self._get_hyper("l1_regularization_strength", var_dtype)
        l2_t = self._get_hyper("l2_regularization_strength", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta1_power = tf.pow(beta1_t, local_step)
        beta2_power = tf.pow(beta2_t, local_step)
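        # Adam-style bias correction: scale the step size by
        # sqrt(1 - beta2^t) / (1 - beta1^t) to offset the zero initialization
        # of the moment estimates in early iterations.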
        lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)
        update_vs = []
        if self._beta1 == 0.0:
            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
            v = self.get_slot(var, "v")
            grad2 = grad * grad
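            # Unlike Adam's exponential moving average of g^2, Yogi moves v
            # toward g_t^2 additively, at a rate controlled by sign(g_t^2 - v).
            # The "tanh" option is a smooth surrogate for the sign; the factor
            # 10 sharpens it toward +/-1 while smoothing the transition near
            # g_t^2 ~ v.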
            if self._activation == "sign":
                sign = tf.sign(grad2 - v)
            elif self._activation == "tanh":
                sign = tf.tanh(10 * (grad2 - v))
            else:
                raise NotImplementedError(
                    "Activation function must be `sign` or `tanh`."
                )
            v_t = v.assign_add(
                (1 - beta2_t) * sign * grad2, use_locking=self._use_locking
            )
            v_sqrt = tf.sqrt(v_t)
            # Yogi per-coordinate effective learning rate
            per_coord_lr = lr / (v_sqrt + epsilon_t)
            # Variable update
            # Step 1: Gradient descent
            new_var = var - per_coord_lr * grad
            # Step 2: Prox operator
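            # With L1 > 0, `_solve(a, b, c)` returns
            # argmin_w (a/2) w^2 + b w + c |w|, a soft-thresholding step;
            # with only L2 > 0, the proximal step reduces to a closed-form
            # shrinkage.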
            if self._l1_regularization_strength > 0:
                new_var = _solve(
                    1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr
                )
            elif self._l2_regularization_strength > 0:
                new_var = new_var / (1 + l2_t * per_coord_lr)
            # Step 3: Update
            var_update = var.assign(new_var, use_locking=self._use_locking)
            update_vs.append(var_update)
            update_vs.append(v_t)
        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
            m = self.get_slot(var, "m")
            m_t = m.assign(
                m * beta1_t + grad * (1 - beta1_t), use_locking=self._use_locking
            )
            # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
            v = self.get_slot(var, "v")
            grad2 = grad * grad
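            # Same Yogi second-moment update as in the `beta1 == 0` branch above.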
            if self._activation == "sign":
                sign = tf.sign(grad2 - v)
            elif self._activation == "tanh":
                sign = tf.tanh(10 * (grad2 - v))
            else:
                raise NotImplementedError(
                    "Activation function must be `sign` or `tanh`."
                )
            v_t = v.assign_add(
                (1 - beta2_t) * sign * grad2, use_locking=self._use_locking
            )
            v_sqrt = tf.sqrt(v_t)
            # Yogi per-coordinate effective learning rate
            per_coord_lr = lr / (v_sqrt + epsilon_t)
            # Variable update
            # Step 1: Gradient descent step using the momentum m_t
            new_var = var - per_coord_lr * m_t
            # Step 2: Prox operator
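            # Same composite L1/L2 proximal step as in the `beta1 == 0` branch.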
            if self._l1_regularization_strength > 0:
                new_var = _solve(
                    1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr
                )
            elif self._l2_regularization_strength > 0:
                new_var = new_var / (1 + l2_t * per_coord_lr)
            # Step 3: Update
            var_update = var.assign(new_var, use_locking=self._use_locking)
            update_vs.append(var_update)
            update_vs.append(m_t)
            update_vs.append(v_t)
        # Create an op that groups all the above operations.
        return tf.group(*update_vs)
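    # A minimal usage sketch (for illustration only, assuming the public `Yogi`
    # class from this module, accessed as `tfa.optimizers.Yogi`):
    #
    #     opt = tfa.optimizers.Yogi(learning_rate=1e-2)
    #     model.compile(optimizer=opt, loss="mse")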