in kfac/python/ops/fisher_blocks.py [0:0]
def _process_data(self, grads_list):
  """Process temporal/multi-use data into the format used by the factors.

  This function takes inputs and grads_lists data and processes it into
  one of the formats expected by the FisherFactor classes (depending on
  the value of the global configuration variable TOWER_STRATEGY).

  It accepts the data in one of two initial formats. The first possible
  format is where self._inputs is a list of list of Tensors. The first index
  is tower, the second is use/time-step. grads_list, meanwhile, is a list
  over sources of such lists of lists.

  The second possible data format is where self._inputs is a list of Tensors
  (over towers), where each Tensor has uses/times-steps folded into the batch
  dimension. i.e. they are Tensors of shape [num_uses * batch_size, ...],
  which represent reshapes of a Tensor of shape [num_uses, batch_size, ...].
  And similarly grads_list is a list over sources of such lists of Tensors.

  There are two possible formats which inputs and grads_list are transformed
  into.

  If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
  a single tensor (represented as a PartitionedTensor object) with all of
  the data from the towers, as well as the uses/time-steps, concatenated
  together. In this tensor the leading dimension is the batch and
  use/time-step dimensions folded together (with 'use' being the major of
  these two, so that the tensors can be thought of as reshapes of ones of
  shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
  tuple over sources of such tensors.

  If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
  tensors over towers. Each of these tensors has a similar format to
  the tensor produced by the "concat" option, except that each contains
  only the data from a single tower. grads_list is similarly formatted
  into a tuple over sources of such tuples.

  Args:
    grads_list: grads_list in its initial format (see above).

  Returns:
    inputs: self._inputs transformed into the appropriate format (see
      above).
    grads_list: grads_list transformed into the appropriate format (see
      above).

  Raises:
    ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
    ValueError: If the given/initial format of self._inputs and grads_list
      isn't recognized, or doesn't agree with self._num_uses.
  """

  inputs = self._inputs

  # The first data format: a list (over towers) of lists (over uses) of
  # Tensors.
  if isinstance(inputs[0], (list, tuple)):
    num_uses = len(inputs[0])
    if self._num_uses is not None and self._num_uses != num_uses:
      raise ValueError("num_uses argument doesn't match length of inputs.")
    else:
      self._num_uses = num_uses

    # Check that all mini-batches/towers have the same number of uses
    if not all(len(input_) == num_uses for input_ in inputs):
      raise ValueError("Length of inputs argument is inconsistent across "
                       "towers.")

    if fisher_factors.TOWER_STRATEGY == "concat":
      # Reverse the tower and use/time-step indices, so that use is now first,
      # and towers is second
      inputs = tuple(zip(*inputs))

      # Flatten the two dimensions
      inputs = nest.flatten(inputs)

      # Merge everything together into a PartitionedTensor. We package it in
      # a singleton tuple since the factors will expect a list over towers
      inputs = (utils.PartitionedTensor(inputs),)

    elif fisher_factors.TOWER_STRATEGY == "separate":
      # Merge together the uses/time-step dimension into PartitionedTensors,
      # but keep the leading dimension (towers) intact for the factors to
      # process individually.
      inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)

    else:
      raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                       "'concat' or 'separate'.")

  # The second data format: a list (over towers) of Tensors with
  # uses/time-steps already folded into the batch dimension.
  else:
    inputs = tuple(inputs)

  # Now we perform the analogous processing for grads_list
  # The first data format.
  if isinstance(grads_list[0][0], (list, tuple)):
    num_uses = len(grads_list[0][0])
    if self._num_uses is not None and self._num_uses != num_uses:
      raise ValueError("num_uses argument doesn't match length of outputs, "
                       "or length of outputs is inconsistent with length of "
                       "inputs.")
    else:
      self._num_uses = num_uses

    # Check that all sources/towers have the same number of uses.
    if not all(len(grad) == num_uses for grads in grads_list
               for grad in grads):
      raise ValueError("Length of outputs argument is inconsistent across "
                       "towers.")

    if fisher_factors.TOWER_STRATEGY == "concat":
      # Reverse the tower and use/time-step indices, so that use is now first,
      # and towers is second
      grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)

      # Flatten the two dimensions, leaving the leading dimension (source)
      # intact
      grads_list = tuple(nest.flatten(grads) for grads in grads_list)

      # Merge inner dimensions together into PartitionedTensors. We package
      # them in a singleton tuple since the factors will expect a list over
      # towers
      grads_list = tuple((utils.PartitionedTensor(grads),)
                         for grads in grads_list)

    elif fisher_factors.TOWER_STRATEGY == "separate":
      # Merge together the uses/time-step dimension into PartitionedTensors,
      # but keep the leading dimension (towers) intact for the factors to
      # process individually.
      grads_list = tuple(tuple(utils.PartitionedTensor(grad)
                               for grad in grads)
                         for grads in grads_list)

    else:
      raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                       "'concat' or 'separate'.")

  # The second data format.
  else:
    grads_list = tuple(tuple(grads) for grads in grads_list)

  # If both inputs and outputs arrived in the second (folded) format, the
  # number of uses cannot be inferred and must have been given explicitly.
  if self._num_uses is None:
    # Fixed unbalanced parenthesis in the original message: the "(e.g. ..."
    # clause was never closed.
    raise ValueError("You must supply a value for the num_uses argument if "
                     "the number of uses cannot be inferred from inputs or "
                     "outputs arguments (e.g. if they are both given in the "
                     "single Tensor format, instead of as lists of Tensors).")

  return inputs, grads_list