def _compute_qmodel_hyperparams()

in kfac/python/ops/optimizer.py [0:0]


  def _compute_qmodel_hyperparams(self, m, c, b, fixed_mu=None):
    """Compute optimal update hyperparameters from the quadratic model.

    More specifically, if L is the loss we minimize a quadratic approximation
    of L(theta + d) which we denote by qmodel(d) with
    d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where

      qmodel(d) = (1/2) * d^T * C * d + grad^T*d + L(theta) .

    Unlike in the KL clipping approach we use the non-approximated quadratic
    model where the curvature matrix C is the true Fisher (or GGN) on the
    current mini-batch (computed without any approximations beyond mini-batch
    sampling), with the usual Tikhonov damping/regularization applied,

      C = F + damping * I

    See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
    the formula.  See Appendix C for a discussion of the trick of using
    a factorized Fisher matrix to more efficiently compute the required
    vector-matrix-vector products.

    The choice between the full 2x2 solve and the reduced 1x1 solve is made
    inside the graph with `tf.cond`, based on whether c[1][0] (the linear
    coefficient of mu) is exactly zero, which this method treats as "the
    previous update is all zeros".

    Args:
      m: 2 by 2 matrix representing the quadratic term, given as a list of
        lists of 0D Tensors (indexed m[i][j]).
      c: 2 by 1 vector representing the linear term, given as a list of
        one-element lists of 0D Tensors (indexed c[i][0]).
      b: 2 by 2 matrix, laid out like `m`, representing only the contribution
        of the damping to the quadratic term. Only used when
        _INCLUDE_DAMPING_IN_QMODEL_CHANGE is False.
      fixed_mu: A fixed value of mu to use instead of the optimal one. It is
        cast to the dtype of alpha so that both `tf.cond` branches agree.
        (Default: None)
    Returns:
      (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
      quadratic model, and
      qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
                    = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
      When self._qmodel_update_rescale is set, qmodel_change is evaluated at
      the rescaled step (rescale * alpha, rescale * mu), while alpha and mu
      themselves are returned unscaled.
    """

    def non_zero_prevupd_case():
      r"""Computes optimal (alpha, mu) given non-zero previous update.

      We solve the full 2x2 linear system. See Martens & Grosse (2015),
      Section 7, definition of $\alpha^*$ and $\mu^*$.

      Returns:
        (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
        the quadratic model, and
        qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
      """
      if fixed_mu is None:
        sol = -1. * _two_by_two_solve(m, c)
        alpha = sol[0, 0]
        mu = sol[1, 0]

        if self._qmodel_update_rescale is None:
          # At the minimizer sol = -m^{-1} c the quadratic value simplifies to
          # 0.5 * c^T sol. It should be equivalent to
          # _eval_quadratic(m, c, sol) if everything is working properly.
          qmodel_change = 0.5 * tf.reduce_sum(sol * c)
        else:
          sol = self._qmodel_update_rescale * sol
          qmodel_change = _eval_quadratic(m, c, sol)
      else:
        # With mu held fixed, setting d(qmodel)/d(alpha) to zero gives
        # m00 * alpha + m01 * mu + c0 = 0.
        alpha = -1. * (fixed_mu * m[0][1] + c[0][0]) / (m[0][0])
        # Cast like zero_prevupd_case does: tf.cond requires both branches to
        # return the same dtype, and a Python-float fixed_mu would otherwise
        # become float32 regardless of alpha's dtype.
        mu = tf.cast(fixed_mu, dtype=alpha.dtype)

        sol = [[alpha], [mu]]
        if self._qmodel_update_rescale is not None:
          sol = self._qmodel_update_rescale * tf.convert_to_tensor(sol)

        qmodel_change = _eval_quadratic(m, c, sol)

      # Subtract out the damping-related penalty (same for both sub-cases).
      if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:
        qmodel_change -= (self._sub_damping_out_qmodel_change_coeff
                          * _eval_quadratic_no_c(b, sol))

      return tf.squeeze(alpha), tf.squeeze(mu), tf.squeeze(qmodel_change)

    def zero_prevupd_case():
      r"""Computes optimal (alpha, mu) given all-zero previous update.

      The linear system reduces to 1x1. See Martens & Grosse (2015),
      Section 6.4, definition of $\alpha^*$.

      Returns:
        (alpha, mu, qmodel_change), where alpha is chosen to optimize the
        quadratic model, mu is 0.0 (or fixed_mu if given), and
        qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
      """
      alpha = -c[0][0] / m[0][0]
      # mu is irrelevant here (the previous update is zero) but must match
      # alpha's dtype for tf.cond.
      mu = tf.cast(0.0 if fixed_mu is None else fixed_mu, dtype=alpha.dtype)

      if self._qmodel_update_rescale is None:
        step = alpha
        # At the minimizer alpha = -c0 / m00 the quadratic value simplifies
        # to 0.5 * alpha * c0.
        qmodel_change = 0.5 * alpha * c[0][0]
      else:
        step = self._qmodel_update_rescale * alpha
        qmodel_change = 0.5 * m[0][0] * tf.square(step) + c[0][0] * step

      # Subtract out the damping-related penalty, evaluated at the step
      # actually used for qmodel_change.
      if not _INCLUDE_DAMPING_IN_QMODEL_CHANGE:
        qmodel_change -= (self._sub_damping_out_qmodel_change_coeff
                          * 0.5 * tf.square(step) * b[0][0])

      return alpha, mu, qmodel_change

    return tf.cond(
        tf.equal(c[1][0], 0.0),
        zero_prevupd_case,
        non_zero_prevupd_case)