kfac/python/ops/optimizer.py
def _compute_raw_update_steps(self, grads_and_vars):
  """Computes the raw update steps for the variables given the gradients.

  Note that these "raw updates" are further multiplied by
  -1 * self._learning_rate when the update is eventually applied in the
  superclass (which is GradientDescentOptimizer).

  Args:
    grads_and_vars: List of (gradient, variable) pairs.

  Returns:
    A list of (raw_update, var) tuples, where raw_update is the update to
    apply to the parameter. These updates must actually be used, since they
    carry control dependencies that need to execute.
  """
  if self._momentum_type == "regular":
    # Compute the "preconditioned" gradient.
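    # (Here "preconditioning" means multiplying by the inverse of the damped
    # Kronecker-factored curvature/Fisher approximation.)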
    precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars)
    # Apply "KL clipping" if asked for.
    if self._norm_constraint is not None:
      precon_grads_and_vars = self._clip_updates(
          grads_and_vars, precon_grads_and_vars)
    # Update the velocities and get their values as the "raw" updates.
    raw_updates_and_vars = self._update_velocities(
        precon_grads_and_vars, self._momentum)
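    # When damping adaptation is enabled, the chief also records the change
    # in the objective predicted by the quadratic model for this step; the
    # damping is later adjusted by comparing this prediction against the
    # actual change in the loss (a Levenberg-Marquardt style rule).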
    if self._adapt_damping and self._is_chief:
      def compute_qmodel_change():
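        # The raw updates are eventually multiplied by
        # -1 * self._learning_rate in the superclass, so the same scaling is
        # applied here before estimating the quadratic-model change for the
        # step that will actually be taken.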
        updates_and_vars = sprod_p(
            -1. * self._learning_rate, raw_updates_and_vars)
        return self._compute_approx_qmodel_change(
            updates_and_vars, grads_and_vars)
      maybe_update_qmodel_change = self._maybe_update_qmodel_change(
          compute_qmodel_change)
      with tf.control_dependencies([maybe_update_qmodel_change]):
        # Materializing the generator into a tuple here is important: it
        # forces the tf.identity ops to be created inside the
        # control-dependency context rather than lazily, outside of it.
        return tuple(
            (tf.identity(vec), var)
            for (vec, var) in raw_updates_and_vars)
    else:
      return raw_updates_and_vars
  elif self._momentum_type == "adam":
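    # In this mode the velocity accumulates the *unpreconditioned* gradients,
    # and the preconditioner is applied to the velocity afterwards (whereas
    # in the "regular" mode above, momentum is applied to gradients that
    # have already been preconditioned).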
    velocities_and_vars = self._update_velocities(
        grads_and_vars, self._momentum)
    # The "preconditioned" velocity vector is the raw update step.
    raw_updates_and_vars = self._multiply_preconditioner(velocities_and_vars)
    # Apply "KL clipping" if asked for. Note that here it is applied to the
    # preconditioned velocity (which already folds in the current gradient),
    # unlike in the momentum_type = 'regular' case, where it is applied to
    # the preconditioned gradient before the velocity update.
    if self._norm_constraint is not None:
      raw_updates_and_vars = self._clip_updates(
          velocities_and_vars, raw_updates_and_vars)
    if self._adapt_damping and self._is_chief:
      def compute_qmodel_change():
        # This is a special formula that exploits the structure of this
        # particular update. Note that it uses the approximate Fisher as
        # defined by the inverses, which may be stale (perhaps so stale that
        # they were computed under an old damping value, which can interfere
        # with the damping adaptation method).
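        # Concretely: with d = -learning_rate * raw_update and, ignoring any
        # KL-clipping rescaling, raw_update = F^{-1} * velocity for the
        # damped approximate Fisher F, this expression is the usual
        # quadratic model change 0.5 * d^T F d + g^T d.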
        return (
            0.5 * (self._learning_rate**2) *
            ip_p(raw_updates_and_vars, velocities_and_vars) -
            self._learning_rate * ip_p(raw_updates_and_vars, grads_and_vars))
      maybe_update_qmodel_change = self._maybe_update_qmodel_change(
          compute_qmodel_change)
      with tf.control_dependencies([maybe_update_qmodel_change]):
        # Materializing the generator into a tuple here is important: it
        # forces the tf.identity ops to be created inside the
        # control-dependency context rather than lazily, outside of it.
        return tuple(
            (tf.identity(vec), var)
            for (vec, var) in raw_updates_and_vars)
    else:
      return raw_updates_and_vars
  elif (self._momentum_type == "qmodel"
        or self._momentum_type == "qmodel_fixedmu"):
    precon_grads_and_vars, m, c, b = self._get_qmodel_quantities(
        grads_and_vars)
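    # Roughly speaking, m, c and b parameterize the quadratic model
    # restricted to the span of the preconditioned gradient and the previous
    # update; alpha and mu below are chosen by minimizing the model over
    # that two-dimensional subspace.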
    if self._momentum_type == "qmodel_fixedmu":
      fixed_mu = self._momentum
    else:
      fixed_mu = None
    # Compute the optimal velocity update parameters according to the
    # quadratic model.
    alpha, mu, qmodel_change = self._compute_qmodel_hyperparams(
        m, c, b, fixed_mu=fixed_mu)
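    # Record the model-predicted change together with the effective learning
    # rate and momentum chosen by the quadratic model, so that they can be
    # read back later (e.g. by the damping adaptation machinery or for
    # logging).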
    qmodel_assign_op = tf.group(
        utils.smart_assign(
            self._qmodel_change, qmodel_change, force_cast=True),
        utils.smart_assign(
            self._qmodel_learning_rate, -alpha, force_cast=True),
        utils.smart_assign(
            self._qmodel_momentum, mu, force_cast=True))
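    # The control dependency ensures that the bookkeeping assignments above
    # run whenever the returned updates are evaluated.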
    with tf.control_dependencies([qmodel_assign_op]):
      return self._update_velocities(
          precon_grads_and_vars, mu, vec_coeff=-alpha)