in kfac/python/ops/optimizer.py [0:0]
def _compute_qmodel_hyperparams(self, m, c, b, fixed_mu=None):
"""Compute optimal update hyperparameters from the quadratic model.
More specifically, if L is the loss we minimize a quadratic approximation
of L(theta + d) which we denote by qmodel(d) with
d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where
qmodel(d) = (1/2) * d^T * C * d + grad^T*d + L(theta) .
Unlike in the KL clipping approach we use the non-approximated quadratic
model where the curvature matrix C is the true Fisher (or GGN) on the
current mini-batch (computed without any approximations beyond mini-batch
sampling), with the usual Tikhonov damping/regularization applied,
C = F + damping * I
See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
the formula. See Appendix C for a discussion of the trick of using
a factorized Fisher matrix to more efficiently compute the required
vector-matrix-vector products.
Args:
m: 2 by 2 matrix representing the quadratic term (a list of list of
0D Tensors)
c: a 2 by 1 vector representing the linear term (a list of 0D Tensors)
b: 2 by 2 matrix representing only the contribution of the damping to the
quadratic term
fixed_mu: A fixed value of mu to use instead of the optimal one.
(Default: None)
Returns:
(alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
quadratic model, and
qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
= qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
"""
  def non_zero_prevupd_case():
    r"""Computes optimal (alpha, mu) given a non-zero previous update.

    We solve the full 2x2 linear system. See Martens & Grosse (2015),
    Section 7, definition of $\alpha^*$ and $\mu^*$.

    Returns:
      (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
      the quadratic model, and
        qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
    """
    if fixed_mu is None:
      sol = -1. * _two_by_two_solve(m, c)
      alpha = sol[0, 0]
      mu = sol[1, 0]
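
      # When a rescaling factor is configured the solution is scaled before
      # measuring the model change, so the exact-minimizer shortcut no longer
      # applies and the full quadratic must be evaluated instead.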
      if self._qmodel_update_rescale is None:
        # This is a special formula that takes advantage of the particular
        # relationship of sol to m and c. It should be equivalent to
        # _eval_quadratic(m, c, sol) if everything is working properly.
        qmodel_change = 0.5 * tf.reduce_sum(sol * c)
      else:
        sol = self._qmodel_update_rescale * sol
        qmodel_change = _eval_quadratic(m, c, sol)

      # Subtract out the damping-related penalty
      if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:
        qmodel_change -= (self._sub_damping_out_qmodel_change_coeff
                          * _eval_quadratic_no_c(b, sol))

    else:
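      # With mu held at the supplied fixed value, minimizing the quadratic
      # over alpha alone gives m[0][0]*alpha + m[0][1]*fixed_mu + c[0][0] = 0,
      # which is solved for alpha below.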
      alpha = -1. * (fixed_mu * m[0][1] + c[0][0]) / (m[0][0])
      mu = fixed_mu

      sol = [[alpha], [mu]]
      if self._qmodel_update_rescale is not None:
        sol = self._qmodel_update_rescale * tf.convert_to_tensor(sol)

      qmodel_change = _eval_quadratic(m, c, sol)

      # Subtract out the damping-related penalty
      if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:
        qmodel_change -= (self._sub_damping_out_qmodel_change_coeff
                          * _eval_quadratic_no_c(b, sol))

    return tf.squeeze(alpha), tf.squeeze(mu), tf.squeeze(qmodel_change)
  def zero_prevupd_case():
    r"""Computes optimal (alpha, mu) given an all-zero previous update.

    The linear system reduces to 1x1. See Martens & Grosse (2015),
    Section 6.4, definition of $\alpha^*$.

    Returns:
      (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
      quadratic model, and
        qmodel_change = qmodel(alpha*precon_grad) - qmodel(0).
    """
    alpha = -c[0][0] / m[0][0]

    if fixed_mu is None:
      mu = 0.0
    else:
      mu = fixed_mu
    mu = tf.cast(mu, dtype=alpha.dtype)

    if self._qmodel_update_rescale is None:
      # This is a special formula that takes advantage of the particular
      # relationship of sol to m and c.
      qmodel_change = 0.5 * alpha * c[0][0]

      # Subtract out the damping-related penalty
      if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:
        qmodel_change -= (self._sub_damping_out_qmodel_change_coeff
                          * 0.5 * tf.square(alpha) * b[0][0])
    else:
      sol = self._qmodel_update_rescale * alpha
      qmodel_change = 0.5 * m[0][0] * tf.square(sol) + c[0][0] * sol

      # Subtract out the damping-related penalty
      if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:
        qmodel_change -= (self._sub_damping_out_qmodel_change_coeff
                          * 0.5 * tf.square(sol) * b[0][0])

    return alpha, mu, qmodel_change
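
  # c[1][0] is the linear-term coefficient associated with prev_update; a
  # value of exactly 0.0 is used here to detect the all-zero previous update
  # case (e.g. on the first iteration), where the cheaper 1x1 solve applies.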
  return tf.cond(
      tf.equal(c[1][0], 0.0),
      zero_prevupd_case,
      non_zero_prevupd_case)