in kfac/python/ops/fisher_blocks.py [0:0]
def multiply_matpower(self, vector, exp):
  """Multiplies `vector` by the inverse of this block's approximate matrix.

  Only the inverse (`exp == -1`) is implemented. The computation follows
  either the "option1" or the "option2" approximation, selected by
  `self._option`, using quantities obtained from the input and output
  factors under their respective damping functions.

  Args:
    vector: The layer parameters to multiply, in whatever form
      `utils.layer_params_to_mat2d` accepts.
    exp: The matrix power to multiply by. Must be -1.

  Returns:
    The product, converted back to the structure of `vector` by
    `utils.mat2d_to_layer_params`.

  Raises:
    NotImplementedError: If `exp` is anything other than -1.
  """
  if exp != -1:
    # Adjacent string literals are concatenated verbatim, so the separating
    # space has to be written explicitly.
    raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
                              "multiplications.")

  # pylint: disable=invalid-name

  Z = utils.layer_params_to_mat2d(vector)

  # Derivations were done for "batch_dim==1" case so we need to convert to
  # that orientation:
  Z = tf.transpose(Z)

  # NOTE(review): if self._option matches neither branch below, Z is only
  # transposed and transposed back, i.e. the input is returned unchanged.
  # Presumably the option is validated elsewhere -- confirm.
  if self._option == SeriesFBApproximation.option1:

    # Note that L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.
    L_A, psi_A = self._input_factor.get_option1quants(
        self._input_damping_func)
    L_G, psi_G = self._output_factor.get_option1quants(
        self._output_damping_func)

    def gamma(x):
      # We are assuming that each case has the same number of time-steps.
      # If this stops being the case one shouldn't simply replace this T
      # with its average value. Instead, one needs to go back to the
      # definition of the gamma function from the paper.
      T = self._num_timesteps
      return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))

    # Y = \gamma( psi_G*psi_A^T ) (computed element-wise)
    # Even though Y is Z-independent we are recomputing it from the psi's
    # each time since Y depends on both A and G quantities, and it is
    # relatively cheap to compute.
    # Reshaping psi_G to a column makes the product broadcast as an outer
    # product rather than an element-wise/inner one.
    Y = gamma(tf.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)

    # Z = L_G^T * Z * L_A
    # This is equivalent to the following computation from the original
    # pseudo-code:
    # Z = G0^{-1/2} * Z * A0^{-1/2}
    # Z = U_G^T * Z * U_A
    Z = tf.matmul(L_G, tf.matmul(Z, L_A), transpose_a=True)

    # Z = Z .* Y
    Z *= Y

    # Z = L_G * Z * L_A^T
    # This is equivalent to the following computation from the original
    # pseudo-code:
    # Z = U_G * Z * U_A^T
    # Z = G0^{-1/2} * Z * A0^{-1/2}
    Z = tf.matmul(L_G, tf.matmul(Z, L_A, transpose_b=True))

  elif self._option == SeriesFBApproximation.option2:

    # Note that P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1},
    # and K_A = A_0^{-1/2} * E_A and K_G = G_0^{-1/2} * E_G.
    P_A, K_A, mu_A = self._input_factor.get_option2quants(
        self._input_damping_func)
    P_G, K_G, mu_G = self._output_factor.get_option2quants(
        self._output_damping_func)

    # Our approach differs superficially from the pseudo-code in the paper
    # in order to reduce the total number of matrix-matrix multiplies.
    # In particular, the first three computations in the pseudo code are
    #   Z = G0^{-1/2} * Z * A0^{-1/2}
    #   Z = Z - hPsi_G^T * Z * hPsi_A
    #   Z = E_G^T * Z * E_A
    # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}, so that
    #   C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}
    # the entire computation can be written as
    #   Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}
    #       - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A
    #     = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}
    #       - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A
    #     = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A
    #       - E_G^T * G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A
    #     = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A
    # This final expression is computed by the following two lines:
    # Z = Z - P_G * Z * P_A^T
    Z -= tf.matmul(P_G, tf.matmul(Z, P_A, transpose_b=True))

    # Z = K_G^T * Z * K_A
    Z = tf.matmul(K_G, tf.matmul(Z, K_A), transpose_a=True)

    # Z = Z ./ (1*1^T - mu_G*mu_A^T)
    # Be careful with the outer product. We don't want to accidentally
    # make it an inner-product instead.
    tmp = 1.0 - tf.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
    # Prevent some numerical issues by setting any 0.0 eigs to 1.0
    tmp += 1.0 * tf.cast(tf.equal(tmp, 0.0), dtype=tmp.dtype)
    Z /= tmp

    # We now perform the transpose/reverse version of the operations
    # derived above, whose derivation from the original pseudo-code is
    # analogous.
    # Z = K_G * Z * K_A^T
    Z = tf.matmul(K_G, tf.matmul(Z, K_A, transpose_b=True))

    # Z = Z - P_G^T * Z * P_A
    Z -= tf.matmul(P_G, tf.matmul(Z, P_A), transpose_a=True)

    # Z = normalize (1/E[T]) * Z
    # Note that this normalization is done because we compute the statistics
    # by averaging, not summing, over time. (And the gradient is presumably
    # summed over time, not averaged, and thus their scales are different.)
    Z /= tf.cast(self._num_timesteps, Z.dtype)

  # Convert back to the "batch_dim==0" orientation.
  Z = tf.transpose(Z)

  return utils.mat2d_to_layer_params(vector, Z)