def make_inverse_update_ops()

in kfac/python/ops/fisher_factors.py [0:0]


  def make_inverse_update_ops(self):
    """Builds the update ops for every registered inverse-like quantity.

    Handles the registered matrix powers (keyed by exponent and damping id) as
    well as the registered Cholesky factors and Cholesky inverses (both keyed
    by damping id). As a side effect, `self._eigendecomp` is reset to False.

    Returns:
      A list of ops. When the eigendecomposition path is taken, all of the
      matrix power assignments are bundled into a single grouped op; the
      Cholesky related ops are appended after the matrix power ops.
    """
    matpowers = self._matpower_by_exp_and_damping
    cholesky_vars = self._cholesky_by_damping
    cholesky_inverse_vars = self._cholesky_inverse_by_damping

    inverse_count = len([None for power, _ in matpowers if power == -1])
    non_inverse_count = len(matpowers) - inverse_count

    # The eigendecomposition is used when it was requested explicitly, when
    # any power other than -1 is registered, or when enough inverses are
    # registered to reach the threshold.
    use_eig = (
        self._eigendecomp or non_inverse_count >= 1 or
        inverse_count >= EIGENVALUE_DECOMPOSITION_THRESHOLD)

    # Every damping function is evaluated exactly once up front, so that the
    # quantities below which share a damping id also share its value.
    damping_by_id = {}
    for damping_key, damping_func in self._damping_funcs_by_id.items():
      damping_by_id[damping_key] = tf.cast(damping_func(), self._dtype)

    ops = []
    if use_eig:
      eigenvalues, eigenvectors = self.get_eigendecomp()  # pylint: disable=unpacking-non-sequence

      assign_ops = []
      for (power, damping_key), target in matpowers.items():
        shifted_eigenvalues = eigenvalues + damping_by_id[damping_key]
        scaled_eigenvectors = eigenvectors * shifted_eigenvalues**power
        value = tf.matmul(scaled_eigenvectors, tf.transpose(eigenvectors))
        assign_ops.append(utils.smart_assign(target, value))

      # The assignments all reuse one eigendecomposition, so they are grouped
      # in order to run on a single device.
      ops = [tf.group(*assign_ops)]
    else:
      for (power, damping_key), target in matpowers.items():
        # Only plain inverses can reach this branch.
        assert power == -1
        damping_value = damping_by_id[damping_key]
        inverse_value = utils.posdef_inv(self.cov, damping_value)
        ops.append(utils.smart_assign(target, inverse_value))

    # TODO(b/77902055): If inverses are being computed with Cholesky's
    # we can share the work. Instead this code currently just computes the
    # Cholesky a second time. It does at least share work between requests for
    # Cholesky's and Cholesky inverses with the same damping id.
    for damping_key, inverse_var in cholesky_inverse_vars.items():
      damping_value = damping_by_id[damping_key]
      factor = utils.cholesky(self.cov, damping_value)

      grouped = []
      if damping_key in cholesky_vars:
        # The factor itself was requested for this damping id as well, so it
        # is stored alongside its inverse.
        grouped.append(utils.smart_assign(cholesky_vars[damping_key], factor))

      # The inverse of the factor is obtained by solving against the identity.
      dim = factor.shape.as_list()[0]
      eye = tf.eye(dim, dtype=factor.dtype)
      factor_inverse = tf.matrix_triangular_solve(factor, eye)
      grouped.append(utils.smart_assign(inverse_var, factor_inverse))

      ops.append(tf.group(*grouped))

    # Cholesky factors that were requested without a matching inverse.
    for damping_key, factor_var in cholesky_vars.items():
      if damping_key in cholesky_inverse_vars:
        continue
      damping_value = damping_by_id[damping_key]
      factor = utils.cholesky(self.cov, damping_value)
      ops.append(utils.smart_assign(factor_var, factor))

    self._eigendecomp = False
    return ops