in tensorflow_addons/optimizers/yogi.py [0:0]
def _resource_apply_sparse(self, grad, var, indices):
    """Applies sparse gradients to a variable.

    Args:
        grad: A tensor for the `values` of `tf.IndexedSlices`.
        var: A `tf.Variable` object.
        indices: A tensor for the `indices` of `tf.IndexedSlices`.

    Returns:
        An op which updates `var` with `grad` and `indices`.

    Raises:
        NotImplementedError: If `self._activation` is neither "sign" nor
            "tanh".
    """
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    beta1_t = self._get_hyper("beta_1", var_dtype)
    beta2_t = self._get_hyper("beta_2", var_dtype)
    epsilon_t = self._get_hyper("epsilon", var_dtype)
    l1_t = self._get_hyper("l1_regularization_strength", var_dtype)
    l2_t = self._get_hyper("l2_regularization_strength", var_dtype)
    local_step = tf.cast(self.iterations + 1, var_dtype)
    beta1_power = tf.pow(beta1_t, local_step)
    beta2_power = tf.pow(beta2_t, local_step)
    # Adam-style bias-corrected step size.
    lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)

    def _yogi_update(step_direction):
        """Second-moment (Yogi) update, prox step, and variable update.

        Shared by both branches below; `step_direction` is the descent
        direction (raw gradient when beta1 == 0, first moment otherwise).
        Returns `(var_update_op, v_update_op)`.
        """
        # v_t = v + sign(g_t^2 - v) * (1 - beta2) * g_t^2
        v = self.get_slot(var, "v")
        grad2 = grad * grad
        v_slice = tf.gather(v, indices)
        # "tanh" is a smooth surrogate for the sign in the Yogi update.
        if self._activation == "sign":
            sign = tf.sign(grad2 - v_slice)
        elif self._activation == "tanh":
            sign = tf.tanh(10 * (grad2 - v_slice))
        else:
            raise NotImplementedError("Activation function can be sign or tanh")
        v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
        v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
        v_sqrt = tf.sqrt(v_scaled_g_values)
        # Yogi effective LR
        per_coord_lr = lr / (v_sqrt + epsilon_t)
        # Variable update
        # Step 1: Gradient descent
        var_slice = tf.gather(var, indices)
        new_var = var_slice - per_coord_lr * step_direction
        # Step 2: Prox operator (closed-form soft-thresholding when l1 > 0,
        # plain shrinkage when only l2 > 0).
        if self._l1_regularization_strength > 0:
            new_var = _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
        elif self._l2_regularization_strength > 0:
            new_var = new_var / (1 + l2_t * per_coord_lr)
        # Step 3: Update
        var_update = self._resource_scatter_update(var, indices, new_var)
        return var_update, v_t

    update_vs = []
    if self._beta1 == 0.0:
        # No momentum: descend along the raw gradient.
        var_update, v_t = _yogi_update(grad)
        update_vs.append(var_update)
        update_vs.append(v_t)
    else:
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        # Decay the full dense slot first; the sparse add below must run
        # after it, hence the control dependency.
        m_t = m.assign(m * beta1_t, use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_slice = tf.gather(m, indices) + m_scaled_g_values
            m_t = self._resource_scatter_update(m, indices, m_slice)
        # Descend along the (bias-uncorrected) first moment.
        var_update, v_t = _yogi_update(m_slice)
        update_vs.append(var_update)
        update_vs.append(m_t)
        update_vs.append(v_t)

    # Create an op that groups all the above operations
    return tf.group(*update_vs)