def _compute_raw_update_steps()

in kfac/python/ops/optimizer.py


  def _compute_raw_update_steps(self, grads_and_vars):
    """Computes the raw update steps for the variables given the gradients.

    Note that these "raw updates" are further multiplied by
    -1*self._learning_rate when the update is eventually applied in the
    superclass (which is GradientDescentOptimizer).

    Args:
      grads_and_vars: List of (gradient, variable) pairs.

    Returns:
      A list of tuples (raw_update, var), where raw_update is the update to
      apply to the parameter. These updates must actually be used, since they
      carry control dependencies that need to execute.
    """

    if self._momentum_type == "regular":
      # Compute "preconditioned" gradient.
      precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars)

      # Apply "KL clipping" if asked for.
      if self._norm_constraint is not None:
        precon_grads_and_vars = self._clip_updates(grads_and_vars,
                                                   precon_grads_and_vars)

      # Update the velocities and use their new values as the "raw" updates.
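      # (i.e. velocity <- momentum * velocity + precon_grad).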
      raw_updates_and_vars = self._update_velocities(precon_grads_and_vars,
                                                     self._momentum)

      if self._adapt_damping and self._is_chief:

        def compute_qmodel_change():
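          # The step eventually applied is -learning_rate * raw_update, so
          # rescale the raw updates by that factor before evaluating the
          # quadratic model change.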
          updates_and_vars = sprod_p(-1. * self._learning_rate,
                                     raw_updates_and_vars)
          return self._compute_approx_qmodel_change(updates_and_vars,
                                                    grads_and_vars)

        maybe_update_qmodel_change = self._maybe_update_qmodel_change(
            compute_qmodel_change)

        with tf.control_dependencies([maybe_update_qmodel_change]):
          # Materializing this as a tuple (rather than leaving it as a lazy
          # generator) is important so that the identity ops are actually
          # created inside the control-dependency context.
          return tuple((tf.identity(vec), var)
                       for (vec, var) in raw_updates_and_vars)
      else:
        return raw_updates_and_vars

    elif self._momentum_type == "adam":
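      # For "adam"-style momentum the velocity accumulates the raw
      # (unpreconditioned) gradients, and the preconditioner is applied to
      # the velocity afterwards, analogous to how Adam preconditions its
      # first-moment estimate.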
      velocities_and_vars = self._update_velocities(grads_and_vars,
                                                    self._momentum)
      # The "preconditioned" velocity vector is the raw update step.
      raw_updates_and_vars = self._multiply_preconditioner(velocities_and_vars)

      # Apply "KL clipping" if asked for. Note that we are applying this to
      # the combined preconditioned gradient + velocity, unlike for the
      # momentum_type = 'regular' case.
      if self._norm_constraint is not None:
        raw_updates_and_vars = self._clip_updates(velocities_and_vars,
                                                  raw_updates_and_vars)

      if self._adapt_damping and self._is_chief:
        def compute_qmodel_change():
          # This is a special formula that exploits the structure of the
          # particular update we are using.  Note that it uses the approximate
          # Fisher as defined by the inverses, which may be stale (perhaps
          # stale enough to reflect an old damping value, which can interfere
          # with the damping adaptation method).
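          # Writing update = -lr * raw_update with raw_update = C^{-1} v
          # (C the damped approx Fisher, v the velocity), the model change
          # 0.5 * update^T C update + grad^T update reduces to the expression
          # below, which avoids an explicit multiplication by C.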
          return (0.5 * (self._learning_rate**2) *
                  ip_p(raw_updates_and_vars, velocities_and_vars) -
                  self._learning_rate * ip_p(raw_updates_and_vars,
                                             grads_and_vars))

        maybe_update_qmodel_change = self._maybe_update_qmodel_change(
            compute_qmodel_change)

        with tf.control_dependencies([maybe_update_qmodel_change]):
          # Materializing this as a tuple (rather than leaving it as a lazy
          # generator) is important so that the identity ops are actually
          # created inside the control-dependency context.
          return tuple((tf.identity(vec), var)
                       for (vec, var) in raw_updates_and_vars)
      else:
        return raw_updates_and_vars

    elif self._momentum_type in ("qmodel", "qmodel_fixedmu"):

      precon_grads_and_vars, m, c, b = self._get_qmodel_quantities(
          grads_and_vars)
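      # m, c, and b are assumed here to encode the two-dimensional quadratic
      # model over the step-size and momentum coefficients, restricted to the
      # span of the preconditioned gradient and the previous update (see
      # _get_qmodel_quantities for the exact contents).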

      if self._momentum_type == "qmodel_fixedmu":
        fixed_mu = self._momentum
      else:
        fixed_mu = None

      # Compute the optimal velocity update parameters according to the
      # quadratic model.
      alpha, mu, qmodel_change = self._compute_qmodel_hyperparams(
          m, c, b, fixed_mu=fixed_mu)
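      # -alpha plays the role of an adaptively chosen learning rate and mu
      # the momentum decay; qmodel_change is the predicted change in the
      # quadratic model under the resulting update.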

      qmodel_assign_op = tf.group(
          utils.smart_assign(self._qmodel_change, qmodel_change,
                             force_cast=True),
          utils.smart_assign(self._qmodel_learning_rate, -alpha,
                             force_cast=True),
          utils.smart_assign(self._qmodel_momentum, mu,
                             force_cast=True))

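      # Ensure the qmodel statistics are recorded before the velocity update
      # (and hence the returned raw updates) can run.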
      with tf.control_dependencies([qmodel_assign_op]):
        return self._update_velocities(
            precon_grads_and_vars, mu, vec_coeff=-alpha)
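
For reference, the helpers sprod_p and ip_p used above operate on lists of
(tensor, variable) pairs. They are defined elsewhere in the kfac codebase;
the following is only a minimal sketch of the semantics the code above
assumes (a scalar product over a pair list, and an inner product between two
pair lists), not the library's actual implementation:

import tensorflow as tf

def sprod_p(scalar, vecs_and_vars):
  # Scale every tensor in a list of (tensor, var) pairs by a scalar.
  return tuple((scalar * vec, var) for vec, var in vecs_and_vars)

def ip_p(vecs_and_vars1, vecs_and_vars2):
  # Sum of inner products between corresponding tensors in two lists of
  # (tensor, var) pairs.
  return tf.add_n([tf.reduce_sum(vec1 * vec2)
                   for (vec1, _), (vec2, _) in zip(vecs_and_vars1,
                                                   vecs_and_vars2)])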