in kfac/python/ops/fisher_factors.py [0:0]
def get_matpower(self, exp, damping_func):
  """Returns a LinearOperator for the damped cov matrix raised to `exp`.

  Only `exp == 1` and `exp == -1` are supported.

  For `exp == 1` the operator is built from the current cov via
  `self._make_cov_linear_operator`, using `damping_func()` cast to
  `self._dtype` as the damping.

  For `exp == -1` the operator is assembled from quantities cached per damping
  id: the inverse in `self._matpower_by_exp_and_damping`, the damping value in
  `self._damping_var_by_id`, and (unless ASSUME_ZERO_MEAN_ACTIVATIONS) the
  `cov_inv_mu` vector and rank-one scale. These are variables updated by the
  inverse ops, so the result may be stale / inconsistent with the latest
  value of `self.cov`. Structurally the returned operator is

      blockdiag(cov_inv kron I_{kw_kh}, [1 / damping])   # bias block only if
                                                          # self._has_bias
        - scale * mean_update mean_update^T               # omitted when
                                                          # ASSUME_ZERO_MEAN_ACTIVATIONS

  Args:
    exp: The exponent; must be 1 or -1.
    damping_func: Callable returning the damping. For `exp == -1` it is mapped
      through `graph_func_to_id` to find the cached quantities.

  Returns:
    A linear operator (for `exp == 1`, whatever
    `self._make_cov_linear_operator` returns; for `exp == -1`, a
    `tf.linalg.LinearOperator`).

  Raises:
    ValueError: If `exp` is neither 1 nor -1.
  """
  if exp == 1:
    return self._make_cov_linear_operator(
        damping=tf.cast(damping_func(), dtype=self._dtype))
  if exp != -1:
    raise ValueError(
        "ConvInputSUAKroneckerFactor only supports computing the cov matrix "
        "(exp == 1) or its inverse (exp == -1); got exp={}.".format(exp))

  damping_id = graph_func_to_id(damping_func)
  cov_inv = self._matpower_by_exp_and_damping[(exp, damping_id)]
  damping_value = self._damping_var_by_id[damping_id]

  # Replicate the in_channels x in_channels cov inverse over the kw * kh
  # positions as (cov_inv kron I). The replication is kept implicit through
  # tf.linalg operators rather than materialized, which keeps it cheap.
  operator = tf.linalg.LinearOperatorKronecker([
      tf.linalg.LinearOperatorFullMatrix(
          cov_inv,
          is_non_singular=True,
          is_self_adjoint=True,
          is_positive_definite=True,
          is_square=True),
      tf.linalg.LinearOperatorIdentity(
          num_rows=self._kw_kh, dtype=self._dtype)
  ])

  if self._has_bias:
    # The bias dimension adds a trailing 1 x 1 block holding the inverse of
    # the damping constant.
    bias_operator = tf.linalg.LinearOperatorFullMatrix(
        [[1. / damping_value]],
        is_non_singular=True,
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)
    operator = tf.linalg.LinearOperatorBlockDiag([operator, bias_operator])

  if ASSUME_ZERO_MEAN_ACTIVATIONS:
    return operator

  cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id]
  scale = self._rank_one_update_scale_by_damping_id[damping_id]

  # Build (cov_inv_mu kron 1's) by tiling along the last axis and flattening,
  # which repeats each entry kw_kh times consecutively to match the
  # (cov_inv kron I) ordering above. Assumes cov_inv_mu is a column vector of
  # shape [in_channels, 1] -- TODO confirm against where it is assigned.
  mean_update = tf.reshape(tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1,))
  if self._has_bias:
    # The bias coordinate of the mean-update vector is 1 / damping, matching
    # the extra 1 x 1 block appended to the operator.
    mean_update = append_homog(mean_update, homog_value=(1. / damping_value))
  mean_update = tf.expand_dims(mean_update, axis=1)

  # LinearOperatorLowRankUpdate represents base + u v^T. With u = mean_update
  # and v = -scale * mean_update, this subtracts scale * u u^T, the
  # Sherman-Morrison correction for the mean-activation term. `scale` is
  # presumably the precomputed Sherman-Morrison coefficient; it is set where
  # `_rank_one_update_scale_by_damping_id` is populated (not visible here).
  return tf.linalg.LinearOperatorLowRankUpdate(
      operator,
      mean_update,
      v=-scale * mean_update,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)