in kfac/python/ops/estimator.py [0:0]
def _instantiate_factors(self):
"""Instantiates FisherFactors' variables.
Raises:
ValueError: If estimation_mode was improperly specified at construction.
"""
blocks = self.blocks
tensors_to_compute_grads = [
block.tensors_to_compute_grads() for block in blocks
]
if self._compute_params_stats:
tensors_to_compute_grads = tensors_to_compute_grads + self.variables
try:
grads_lists = self._gradient_fns[self._estimation_mode](
tensors_to_compute_grads)
except KeyError:
raise ValueError("Unrecognized value {} for estimation_mode.".format(
self._estimation_mode))
if any(grad is None for grad in nest.flatten(grads_lists)):
tensors_flat = nest.flatten(tensors_to_compute_grads)
grads_flat = nest.flatten(grads_lists)
bad_tensors = tuple(
tensor for tensor, grad in zip(tensors_flat, grads_flat)
if grad is None)
bad_string = ""
for tensor in bad_tensors:
bad_string += "\t{}\n".format(tensor)
raise ValueError("It looks like you registered one of more tensors that "
"the registered loss/losses don't depend on. (These "
"returned None from tf.gradients.) The tensors were:"
"\n\n" + bad_string)
if self._compute_params_stats:
idx = len(blocks)
params_stats_unnorm = tuple(tf.add_n(grad_list)
for grad_list in grads_lists[idx:])
scalar = 1. / tf.cast(self._batch_size,
dtype=params_stats_unnorm[0].dtype)
params_stats = utils.sprod(scalar, params_stats_unnorm)
# batch_size should be the per-replica batch size and thus we do a
# cross-replica mean instead of a sum here
self._params_stats = tuple(utils.all_average(tensor)
for tensor in params_stats)
grads_lists = grads_lists[:idx]
for grads_list, block in zip(grads_lists, blocks):
block.instantiate_factors(grads_list, self.damping)