tensorflow_addons/optimizers/adabelief.py [211:233]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))

        if self.rectify:
            r_t_numerator = (sma_t - 4.0) * (sma_t - 2.0) * sma_inf
            r_t_denominator = (sma_inf - 4.0) * (sma_inf - 2.0) * sma_t
            r_t = tf.sqrt(r_t_numerator / r_t_denominator)
            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
            var_t = tf.where(
                sma_t >= sma_threshold,
                r_t * m_corr_t / (v_corr_t + epsilon_t),
                m_corr_t,
            )
        else:
            var_t = m_corr_t / (v_corr_t + epsilon_t)

        if self._has_weight_decay:
            var_t += wd_t * var
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



tensorflow_addons/optimizers/adabelief.py [282:304]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))

        if self.rectify:
            r_t_numerator = (sma_t - 4.0) * (sma_t - 2.0) * sma_inf
            r_t_denominator = (sma_inf - 4.0) * (sma_inf - 2.0) * sma_t
            r_t = tf.sqrt(r_t_numerator / r_t_denominator)
            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
            var_t = tf.where(
                sma_t >= sma_threshold,
                r_t * m_corr_t / (v_corr_t + epsilon_t),
                m_corr_t,
            )
        else:
            var_t = m_corr_t / (v_corr_t + epsilon_t)

        if self._has_weight_decay:
            var_t += wd_t * var
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



