def _compute_qmodel()

in kfac/python/ops/optimizer.py [0:0]


  def _compute_qmodel(self,
                      raw_updates_and_vars,
                      prev_updates_and_vars,
                      grads_and_vars):
    """Computes the 2 dimensional version of the (exact) quadratic model.

       The two dimesions are the update and the previous update vectors.

       The arguments are all lists of (Tensor, Variable) pairs where the
       variables are the same and in the same order.

    Args:
      raw_updates_and_vars: a list of (precond grad, variable) pairs. Raw update
        proposal to apply to the variables (before scaling by learning rate and
        addition of velocity/momentum).
      prev_updates_and_vars: a list of (previous update, variable) pairs.
        Previous update applied to the variables (includes learning rate and
        velocity/momentum).
      grads_and_vars: a list of (gradient, variable) pairs. Gradients for the
        parameters and the variables that the updates are being applied to. The
        order of this list must correspond to the order of the other arguments.
        (Note that this function doesn't actually apply the update.)

    Returns:
      m, c, and b. m is the 2 by 2 matrix representing the quadratic term,
      c is a 2 by 1 vector representing the linear term, and b is the 2 by 2
      matrix representing only the contribution of the damping to the quadratic
      term. These are all multi-dimensional lists (lists of lists) of Tensors.
    """

    # Raw update proposal to apply to the variables (before scaling by learning
    # rate and addition of velocity/momentum).
    raw_updates, _ = zip(*raw_updates_and_vars)
    prev_updates, _ = zip(*prev_updates_and_vars)
    grads, variables = zip(*grads_and_vars)

    utils.assert_variables_match_pairs_list(
        raw_updates_and_vars, prev_updates_and_vars,
        error_message="_compute_qmodel raw_updates_and_vars and "
        "prev_updates_and_vars differ.")
    utils.assert_variables_match_pairs_list(
        prev_updates_and_vars, grads_and_vars,
        error_message="_compute_qmodel prev_updates_and_vars and "
        "grads_and_vars differ.")

    cmvpc = cmvp.CurvatureMatrixVectorProductComputer(
        self.layers,
        variables,
        colocate_gradients_with_ops=self._colocate_gradients_with_ops)

    # Compute the matrix-vector products with the transposed Fisher factor
    # (or GGN factor)
    if self.mat_type == "Fisher":
      mft_updates = cmvpc.multiply_fisher_factor_transpose(raw_updates)
      mft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
    elif self.mat_type == "GGN" or self.mat_type == "Empirical_Fisher":
      mft_updates = cmvpc.multiply_ggn_factor_transpose(raw_updates)
      mft_prev_updates = cmvpc.multiply_ggn_factor_transpose(prev_updates)

    batch_size = tf.cast(self._batch_size, dtype=mft_updates[0].dtype)

    damping = tf.cast(self.damping, dtype=raw_updates[0].dtype)
    b_11 = damping * ip(raw_updates, raw_updates)
    b_21 = damping * ip(prev_updates, raw_updates)
    b_22 = damping * ip(prev_updates, prev_updates)
    b = [[b_11, b_21], [b_21, b_22]]

    # Compute the entries of the 2x2 matrix
    m_11 = ip(mft_updates, mft_updates) / batch_size
    m_21 = ip(mft_prev_updates, mft_updates) / batch_size
    m_22 = (ip(mft_prev_updates, mft_prev_updates)
            / batch_size)
    m = [[m_11 + b_11, m_21 + b_21],
         [m_21 + b_21, m_22 + b_22]]

    m = utils.all_average(m)

    c_1 = ip(grads, raw_updates)
    c_2 = ip(grads, prev_updates)

    c = [[c_1], [c_2]]

    return m, c, b