in kfac/python/ops/fisher_factors.py [0:0]
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
# TODO(b/69918258): Add correctness tests for this method.
# pylint: disable=invalid-name
ops = []
if (len(self._option1quants_by_damping) +
len(self._option2quants_by_damping)):
# Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
# the pseudo-code in the original paper. Because the computations for
# the A and G cases are essentially the same, they can both be performed
# by the same class (this one).
C1 = self.cov_dt1
# Get the eigendecomposition of C0 (= self.cov)
eigen_e, eigen_V = self.get_eigendecomp()
# TODO(b/69678661): Note, there is an implicit assumption here that C1
# and C0 (as represented here by its eigen-decomp) are consistent. This
# could fail to be the case if self._cov and self._cov_dt1 are not updated
# consistently, or are somehow read between or during the cov updates.
# Can this possibly happen? Is there a way to prevent it?
for damping_id, (Lmat_var,
psi_var) in self._option1quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
damping = tf.cast(damping, self._dtype)
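# Compute C0^(-1/2) as V * diag((e + damping)^(-1/2)) * V^T, i.e. the
# inverse matrix square root of the damped C0, from its eigendecomposition.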
invsqrtC0 = tf.matmul(
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
# Might need to enforce symmetry lost due to numerical issues.
invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0
# The following line imposes the symmetry assumed by "Option 1" on C1.
# Strangely, the code can work okay with this line commented out,
# depending on how posdef_eig is defined. I'm not sure why.
C1 = (C1 + tf.transpose(C1)) / 2.0
# hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
hPsi = tf.matmul(tf.matmul(invsqrtC0, C1), invsqrtC0)
# Compute the decomposition U*diag(psi)*U^T = hPsi
psi, U = utils.posdef_eig(hPsi)
# L = C0^(-1/2) * U
Lmat = tf.matmul(invsqrtC0, U)
ops.append(utils.smart_assign(Lmat_var, Lmat))
ops.append(utils.smart_assign(psi_var, psi))
for damping_id, (Pmat_var, Kmat_var,
mu_var) in self._option2quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
damping = tf.cast(damping, self._dtype)
# Compute C0^(-1/2) = V * diag((e + damping)^(-1/2)) * V^T, the inverse
# matrix square root of the damped C0.
invsqrtC0 = tf.matmul(
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
# Might need to enforce symmetry lost due to numerical issues.
invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0
# Compute the product C0^(-1/2) * C1
invsqrtC0C1 = tf.matmul(invsqrtC0, C1)
# hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
hPsi = tf.matmul(invsqrtC0C1, invsqrtC0)
# Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
# Note that we are using the notation mu instead of "m" for the eigenvalues.
# Instead of computing the product hPsi^T * hPsi and then doing an
# eigen-decomposition of it, we just compute the SVD of hPsi and then
# square the singular values to get the eigenvalues. For a justification
# of this approach, see:
# https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
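# The right singular vectors E returned below are the eigenvectors of
# hPsi^T * hPsi, so they can be used directly in the decomposition above.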
sqrtmu, _, E = tf.svd(hPsi)
mu = tf.square(sqrtmu)
# Mathematically, the eigenvalues should not exceed 1.0, but due to
# numerical issues, or possibly inconsistent values of C1 and (the
# eigen-decomposition of) C0, they might. So we enforce this condition.
mu = tf.minimum(mu, 1.0)
# P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C1^T * C0^(-1)
Pmat = tf.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)
# K = C0^(-1/2) * E
Kmat = tf.matmul(invsqrtC0, E)
ops.append(utils.smart_assign(Pmat_var, Pmat))
ops.append(utils.smart_assign(Kmat_var, Kmat))
ops.append(utils.smart_assign(mu_var, mu))
ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
return [tf.group(*ops)]
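
# The following is a minimal, illustrative sketch (not part of the original
# file) of the SVD identity used above: the eigenvalues of hPsi^T * hPsi are
# the squared singular values of hPsi, and the right singular vectors of hPsi
# are the corresponding eigenvectors. The function name and argument are
# hypothetical and exist only for illustration.
def _max_svd_eig_discrepancy(hpsi):
  """Returns max |eig(hPsi^T * hPsi) - (singular values of hPsi)^2|."""
  sqrtmu, _, _ = tf.svd(hpsi)
  # tf.svd returns singular values in descending order; reverse them so the
  # squared values are in ascending order, matching tf.self_adjoint_eig below.
  mu_from_svd = tf.reverse(tf.square(sqrtmu), axis=[0])
  gram = tf.matmul(hpsi, hpsi, transpose_a=True)
  mu_direct, _ = tf.self_adjoint_eig(gram)  # eigenvalues in ascending order
  return tf.reduce_max(tf.abs(mu_from_svd - mu_direct))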