def _instantiate_factors()

in kfac/python/ops/estimator.py [0:0]


  def _instantiate_factors(self):
    """Instantiates FisherFactors' variables.

    Computes the gradients needed by every registered block (and, when
    `self._compute_params_stats` is set, by every entry of `self.variables`)
    with the gradient function selected by `self._estimation_mode`. Each
    block then receives its own gradient list via
    `block.instantiate_factors(grads_list, self.damping)`.

    When parameter statistics are requested, the trailing gradient lists (one
    per variable) are summed with `tf.add_n`, scaled by `1 / self._batch_size`,
    averaged across replicas with `utils.all_average`, and stored in
    `self._params_stats`.

    Raises:
      ValueError: If estimation_mode was improperly specified at construction,
        or if any requested gradient is None (i.e. a registered tensor that
        the registered losses do not depend on).
    """
    blocks = self.blocks
    tensors_to_compute_grads = [
        block.tensors_to_compute_grads() for block in blocks
    ]

    if self._compute_params_stats:
      # list() so this also works when self.variables is a tuple.
      tensors_to_compute_grads = (tensors_to_compute_grads +
                                  list(self.variables))

    # Look the gradient function up separately from calling it, so that a
    # KeyError raised *inside* the gradient computation propagates as-is
    # instead of being misreported as an unrecognized estimation_mode.
    try:
      gradient_fn = self._gradient_fns[self._estimation_mode]
    except KeyError:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))
    grads_lists = gradient_fn(tensors_to_compute_grads)

    grads_flat = nest.flatten(grads_lists)
    if any(grad is None for grad in grads_flat):
      tensors_flat = nest.flatten(tensors_to_compute_grads)
      bad_string = "".join(
          "\t{}\n".format(tensor)
          for tensor, grad in zip(tensors_flat, grads_flat)
          if grad is None)
      raise ValueError("It looks like you registered one or more tensors that "
                       "the registered loss/losses don't depend on. (These "
                       "returned None from tf.gradients.) The tensors were:"
                       "\n\n" + bad_string)

    if self._compute_params_stats:
      # The gradient lists for self.variables were appended after the
      # blocks' lists, so they start at index len(blocks).
      idx = len(blocks)
      params_stats_unnorm = tuple(tf.add_n(grad_list)
                                  for grad_list in grads_lists[idx:])

      if params_stats_unnorm:
        scalar = 1. / tf.cast(self._batch_size,
                              dtype=params_stats_unnorm[0].dtype)
        params_stats = utils.sprod(scalar, params_stats_unnorm)

        # batch_size should be the per-replica batch size and thus we do a
        # cross-replica mean instead of a sum here
        self._params_stats = tuple(utils.all_average(tensor)
                                   for tensor in params_stats)
      else:
        # No variables: avoid indexing into an empty tuple for the dtype.
        self._params_stats = ()

      grads_lists = grads_lists[:idx]

    for grads_list, block in zip(grads_lists, blocks):
      block.instantiate_factors(grads_list, self.damping)