in kfac/python/ops/fisher_factors.py [0:0]
def make_inverse_update_ops(self):
    """Create and return update ops corresponding to registered computations.

    Builds assignment ops that refresh every registered matrix power,
    Cholesky factor and Cholesky-factor inverse from the current value of
    ``self.cov``.

    Matrix powers are computed from an eigendecomposition when any of these
    holds:

    * one was already requested (``self._eigendecomp`` is truthy);
    * a power other than -1 is registered;
    * at least ``EIGENVALUE_DECOMPOSITION_THRESHOLD`` inverses are registered.

    Otherwise each (-1) power is computed directly with ``utils.posdef_inv``.

    Side effect: sets ``self._eigendecomp`` to ``False`` before returning.

    Returns:
      A list of ops, built in this order:

      * Matrix-power assignments. On the eigendecomposition path they are
        combined into a single ``tf.group``; otherwise there is one op per
        registered matrix power.
      * One grouped op per damping id with a registered Cholesky inverse.
        It also updates the Cholesky factor for that id if one is registered.
      * One op per damping id that has only a Cholesky factor registered.
    """
    ops = []

    num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping
                       if exp == -1)
    # The direct (non-eigendecomposition) path below only handles exp == -1,
    # so any other registered exponent forces the eigendecomposition path.
    other_matrix_power_registered = (
        len(self._matpower_by_exp_and_damping) > num_inverses)

    use_eig = (
        self._eigendecomp or other_matrix_power_registered or
        num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)

    # Evaluate each damping function once, up front, rather than once per
    # matrix power / Cholesky request that shares its id.
    damping_value_by_id = {
        damping_id: tf.cast(damping_func(), self._dtype)
        for damping_id, damping_func in self._damping_funcs_by_id.items()}

    if use_eig:
        eigenvalues, eigenvectors = self.get_eigendecomp()  # pylint: disable=unpacking-non-sequence
        # The transpose is independent of the requested power and damping, so
        # build it once instead of once per registered matrix power.
        eigenvectors_t = tf.transpose(eigenvectors)
        for (exp, damping_id), matpower in (
                self._matpower_by_exp_and_damping.items()):
            damping = damping_value_by_id[damping_id]
            # V * (lambda + damping)**exp scales the columns of V, giving
            # V diag((lambda + damping)**exp) V^T after the matmul.
            # NOTE(review): assumes eigenvalues is a vector that broadcasts
            # along the last axis of eigenvectors -- confirm.
            ops.append(
                utils.smart_assign(
                    matpower,
                    tf.matmul(eigenvectors * (eigenvalues + damping)**exp,
                              eigenvectors_t)))
        # These ops share computation and should be run on a single device.
        ops = [tf.group(*ops)]
    else:
        for (exp, damping_id), matpower in (
                self._matpower_by_exp_and_damping.items()):
            # Guaranteed by use_eig: any exponent other than -1 takes the
            # eigendecomposition branch above.
            assert exp == -1, exp
            damping = damping_value_by_id[damping_id]
            ops.append(
                utils.smart_assign(matpower,
                                   utils.posdef_inv(self.cov, damping)))

    # TODO(b/77902055): If inverses are being computed with Cholesky's
    # we can share the work. Instead this code currently just computes the
    # Cholesky a second time. It does at least share work between requests for
    # Cholesky's and Cholesky inverses with the same damping id.
    for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
        cholesky_ops = []
        damping = damping_value_by_id[damping_id]
        cholesky_value = utils.cholesky(self.cov, damping)
        if damping_id in self._cholesky_by_damping:
            # Reuse the factor just computed for the inverse.
            cholesky = self._cholesky_by_damping[damping_id]
            cholesky_ops.append(utils.smart_assign(cholesky, cholesky_value))
        identity = tf.eye(
            cholesky_value.shape.as_list()[0], dtype=cholesky_value.dtype)
        # Solving L X = I yields X = L^{-1}. matrix_triangular_solve defaults
        # to lower=True.
        # NOTE(review): this presumes utils.cholesky returns the lower
        # factor -- confirm.
        cholesky_inv_value = tf.matrix_triangular_solve(cholesky_value, identity)
        cholesky_ops.append(utils.smart_assign(cholesky_inv, cholesky_inv_value))
        ops.append(tf.group(*cholesky_ops))

    # Cholesky factors whose damping id was not already handled in the loop
    # above.
    for damping_id, cholesky in self._cholesky_by_damping.items():
        if damping_id not in self._cholesky_inverse_by_damping:
            damping = damping_value_by_id[damping_id]
            cholesky_value = utils.cholesky(self.cov, damping)
            ops.append(utils.smart_assign(cholesky, cholesky_value))

    # NOTE(review): presumably clears a cached (eigenvalues, eigenvectors)
    # pair so the next get_eigendecomp() call builds a fresh one -- confirm
    # against get_eigendecomp.
    self._eigendecomp = False
    return ops