in tensorflow_addons/optimizers/rectified_adam.py [0:0]
def _prepare_local(self, var_device, var_dtype, apply_state):
    """Compute this step's coefficients for one ``(device, dtype)`` pair.

    Calls the parent class's ``_prepare_local`` first, then derives a set
    of values from the optimizer's hyperparameters and from
    ``self.iterations + 1`` and merges them into
    ``apply_state[(var_device, var_dtype)]``.

    Args:
        var_device: First half of the ``apply_state`` key; not otherwise
            inspected here.
        var_dtype: dtype passed to every hyperparameter lookup and used to
            cast the step counter and ``self.epsilon``; second half of the
            ``apply_state`` key.
        apply_state: Mapping keyed by ``(var_device, var_dtype)``. The entry
            for this key must already exist when this method updates it
            (presumably created by the parent call -- TODO confirm).

    Returns:
        None. ``apply_state`` is updated in place.
    """
    super()._prepare_local(var_device, var_dtype, apply_state)

    base_lr = self._decayed_lr(var_dtype)
    weight_decay = self._decayed_wd(var_dtype)
    beta_1 = self._get_hyper("beta_1", var_dtype)
    beta_2 = self._get_hyper("beta_2", var_dtype)

    # Step count shifted by one, so the first update sees step == 1.
    step = tf.cast(self.iterations + 1, var_dtype)
    beta_1_pow = tf.pow(beta_1, step)
    beta_2_pow = tf.pow(beta_2, step)

    one_minus_beta_1 = 1.0 - beta_1
    inv_one_minus_beta_1_pow = 1.0 / (1.0 - beta_1_pow)
    one_minus_beta_2 = 1.0 - beta_2
    inv_one_minus_beta_2_pow = 1.0 / (1.0 - beta_2_pow)

    # sma_inf depends only on beta_2; sma_t additionally depends on the step.
    rho_inf = 2.0 / one_minus_beta_2 - 1.0
    rho_t = rho_inf - 2.0 * step * beta_2_pow * inv_one_minus_beta_2_pow

    # Accumulate the ratio one factor pair at a time, left to right, so the
    # floating-point evaluation order is the same as a single chained
    # expression.
    ratio = (rho_t - 4.0) / (rho_inf - 4.0)
    ratio = ratio * (rho_t - 2.0) / (rho_inf - 2.0)
    ratio = ratio * rho_inf / rho_t
    # NOTE(review): the ratio is negative (sqrt -> NaN) while rho_t is between
    # 2 and 4; presumably consumers gate on ``sma_t_ge_sma_threshold`` before
    # using ``r_t`` -- confirm against the apply step.
    rectification = tf.sqrt(ratio)

    threshold = self._get_hyper("sma_threshold", var_dtype)
    rho_reaches_threshold = rho_t >= threshold

    scheduled_lr = base_lr
    if self._initial_total_steps > 0:
        total_steps = self._get_hyper("total_steps", var_dtype)
        warmup_fraction = self._get_hyper("warmup_proportion", var_dtype)
        warmup_steps = total_steps * warmup_fraction
        floor_lr = self._get_hyper("min_lr", var_dtype)
        # Clamp to at least one step so the slope below never divides by 0.
        decay_steps = tf.maximum(total_steps - warmup_steps, 1)
        slope = (floor_lr - base_lr) / decay_steps

        in_warmup = step <= warmup_steps
        # Warmup: scale the base rate by the fraction of warmup completed.
        warmup_lr = base_lr * (step / warmup_steps)
        # Afterwards: move linearly from the base rate to ``min_lr`` and
        # hold there once ``decay_steps`` steps have elapsed.
        post_warmup_lr = base_lr + slope * tf.minimum(
            step - warmup_steps, decay_steps
        )
        scheduled_lr = tf.where(in_warmup, warmup_lr, post_warmup_lr)

    apply_state[(var_device, var_dtype)].update(
        {
            "lr_t": scheduled_lr,
            "wd_t": weight_decay,
            "beta_1_t": beta_1,
            "beta_2_t": beta_2,
            "epsilon_t": tf.convert_to_tensor(self.epsilon, var_dtype),
            "local_step": step,
            "beta_1_power": beta_1_pow,
            "beta_2_power": beta_2_pow,
            "sma_inf": rho_inf,
            "sma_t": rho_t,
            "one_minus_beta_1_t": one_minus_beta_1,
            "recip_one_minus_beta_1_power": inv_one_minus_beta_1_pow,
            "one_minus_beta_2_t": one_minus_beta_2,
            "recip_one_minus_beta_2_power": inv_one_minus_beta_2_pow,
            "r_t": rectification,
            "sma_t_ge_sma_threshold": rho_reaches_threshold,
        }
    )