def make_inverse_update_ops()

in kfac/python/ops/fisher_factors.py [0:0]


  def make_inverse_update_ops(self):
    """Create and return update ops corresponding to registered computations.

    For every damping id registered under "Option 1" this builds ops that
    assign the `Lmat` and `psi` variables, and for every damping id registered
    under "Option 2" it builds ops that assign the `Pmat`, `Kmat` and `mu`
    variables.  Both sets of quantities are derived from the
    eigendecomposition of C0 (`self.cov`) and from C1 (`self.cov_dt1`).

    Returns:
      A single-element list holding one `tf.group` op.  The group contains the
      assignment ops created here together with the ops returned by the
      superclass's `make_inverse_update_ops`.
    """
    # TODO(b/69918258): Add correctness tests for this method.
    # pylint: disable=invalid-name

    ops = []

    if self._option1quants_by_damping or self._option2quants_by_damping:

      # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
      # the pseudo-code in the original paper.  Because the computations for
      # the A and G case are essentially the same they can both be performed by
      # the same class (this one).

      C1 = self.cov_dt1

      # Get the eigendecomposition of C0  (= self.cov)
      eigen_e, eigen_V = self.get_eigendecomp()

      # TODO(b/69678661): Note, there is an implicit assumption here that C1
      # and C0 (as represented here by its eigen-decomp) are consistent.  This
      # could fail to be the case if self._cov and self._cov_dt1 are not updated
      # consistently, or are somehow read between or during the cov updates.
      # Can this possibly happen?  Is there a way to prevent it?

      def damped_invsqrtC0(damping_id):
        """Builds the damped C0^(-1/2) for `damping_id` from the eigendecomp."""
        damping = self._damping_funcs_by_id[damping_id]()
        damping = tf.cast(damping, self._dtype)

        invsqrtC0 = tf.matmul(
            eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)

        # Might need to enforce symmetry lost due to numerical issues.
        return (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0

      if self._option1quants_by_damping:
        # The following line imposes the symmetry assumed by "Option 1" on C1.
        # Strangely the code can work okay with this line commented out,
        # depending on how psd_eig is defined.  I'm not sure why.
        #
        # The symmetrized matrix is deliberately stored under its own name
        # (and computed once, since it does not depend on the damping) so that
        # the "Option 2" computations below still see the original C1, which
        # they do not assume to be symmetric.
        C1_sym = (C1 + tf.transpose(C1)) / 2.0

        for damping_id, (Lmat_var,
                         psi_var) in self._option1quants_by_damping.items():

          invsqrtC0 = damped_invsqrtC0(damping_id)

          # hPsi = C0^(-1/2) * C1 * C0^(-1/2)  (hPsi means hat{Psi})
          hPsi = tf.matmul(tf.matmul(invsqrtC0, C1_sym), invsqrtC0)

          # Compute the decomposition U*diag(psi)*U^T = hPsi
          psi, U = utils.posdef_eig(hPsi)

          # L = C0^(-1/2) * U
          Lmat = tf.matmul(invsqrtC0, U)

          ops.append(utils.smart_assign(Lmat_var, Lmat))
          ops.append(utils.smart_assign(psi_var, psi))

      for damping_id, (Pmat_var, Kmat_var,
                       mu_var) in self._option2quants_by_damping.items():

        # compute C0^(-1/2)
        invsqrtC0 = damped_invsqrtC0(damping_id)

        # Compute the product C0^(-1/2) * C1
        invsqrtC0C1 = tf.matmul(invsqrtC0, C1)

        # hPsi = C0^(-1/2) * C1 * C0^(-1/2)  (hPsi means hat{Psi})
        hPsi = tf.matmul(invsqrtC0C1, invsqrtC0)

        # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
        # Note that we are using the notation mu instead of "m" for the
        # eigenvalues.  Instead of computing the product hPsi^T * hPsi and then
        # doing an eigen-decomposition of this we just compute the SVD of hPsi
        # and then square the singular values to get the eigenvalues. For a
        # justification of this approach, see:
        # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
        sqrtmu, _, E = tf.svd(hPsi)
        mu = tf.square(sqrtmu)

        # Mathematically, the eigenvalues should not exceed 1.0, but due to
        # numerical issues, or possible issues with inconsistent values of C1
        # and (the eigen-decomposition of) C0 they might. So we enforce this
        # condition.
        mu = tf.minimum(mu, 1.0)

        # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)
        Pmat = tf.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)

        # K = C_0^(-1/2) * E
        Kmat = tf.matmul(invsqrtC0, E)

        ops.append(utils.smart_assign(Pmat_var, Pmat))
        ops.append(utils.smart_assign(Kmat_var, Kmat))
        ops.append(utils.smart_assign(mu_var, mu))

    ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
    return [tf.group(*ops)]