def _process_data()

in kfac/python/ops/fisher_blocks.py [0:0]


  def _process_data(self, grads_list):
    """Process temporal/multi-use data into the format used by the factors.

    This function takes the inputs (self._inputs) and grads_list data and
    processes them into one of the formats expected by the FisherFactor
    classes (depending on the value of the global configuration variable
    TOWER_STRATEGY).

    It accepts the data in one of two initial formats. The first possible
    format is where self._inputs is a list of lists of Tensors. The first
    index is tower, the second is use/time-step. grads_list, meanwhile, is a
    list over sources of such lists of lists.

    The second possible data format is where self._inputs is a list of Tensors
    (over towers), where each Tensor has uses/time-steps folded into the batch
    dimension, i.e. they are Tensors of shape [num_uses * batch_size, ...],
    which represent reshapes of a Tensor of shape [num_uses, batch_size, ...].
    Similarly, grads_list is a list over sources of such lists of Tensors.
    Data given in this format is only converted from lists to tuples, and the
    number of uses cannot be inferred from it.

    Data given in the first format is transformed into one of two formats.

    If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
    a single tensor (represented as a PartitionedTensor object) with all of
    the data from the towers, as well as the uses/time-steps, concatenated
    together. In this tensor the leading dimension is the batch and
    use/time-step dimensions folded together (with 'use' being the major of
    these two, so that the tensors can be thought of as reshapes of ones of
    shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
    tuple over sources of such singleton tuples.

    If TOWER_STRATEGY is "separate", the inputs are formatted into a tuple of
    tensors over towers. Each of these tensors has a similar format to the
    tensor produced by the "concat" option, except that each contains only the
    data from a single tower. grads_list is similarly formatted into a tuple
    over sources of such tuples.

    When the number of uses is inferred from first-format data, it is stored
    in self._num_uses, but only after the length checks for that argument
    have passed.

    Args:
      grads_list: grads_list in its initial format (see above).

    Returns:
      inputs: self._inputs transformed into the appropriate format (see
        above).
      grads_list: grads_list transformed into the appropriate format (see
        above).

    Raises:
      ValueError: If self._inputs or grads_list contains no data.
      ValueError: If TOWER_STRATEGY is not one of "separate" or "concat"
        (only checked for data given in the first format).
      ValueError: If first-format data has a number of uses that is
        inconsistent across towers or disagrees with self._num_uses.
      ValueError: If the number of uses is neither supplied nor inferable
        (i.e. both arguments are given in the second format).
    """

    def merge_towers_and_uses(per_tower):
      """Merge a list (over towers) of lists (over uses) per TOWER_STRATEGY."""
      if fisher_factors.TOWER_STRATEGY == "concat":
        # Reverse the tower and use/time-step indices so that use is first and
        # tower second, then flatten both dimensions and merge everything into
        # one PartitionedTensor. It is packaged in a singleton tuple since the
        # factors expect a sequence over towers.
        use_major = nest.flatten(tuple(zip(*per_tower)))
        return (utils.PartitionedTensor(use_major),)
      if fisher_factors.TOWER_STRATEGY == "separate":
        # Merge together only the uses/time-step dimension, keeping the
        # leading dimension (towers) intact for the factors to process
        # individually.
        return tuple(utils.PartitionedTensor(per_use) for per_use in per_tower)
      raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                       "'concat' or 'separate'.")

    inputs = self._inputs
    if not inputs:
      raise ValueError("inputs must contain data for at least one tower.")

    # The first data format: a list (over towers) of lists (over uses).
    if isinstance(inputs[0], (list, tuple)):
      num_uses = len(inputs[0])
      if self._num_uses is not None and self._num_uses != num_uses:
        raise ValueError("num_uses argument doesn't match length of inputs.")
      # Check that all mini-batches/towers have the same number of uses.
      if not all(len(input_) == num_uses for input_ in inputs):
        raise ValueError("Length of inputs argument is inconsistent across "
                         "towers.")
      # Only record the inferred value once the inputs have been validated.
      self._num_uses = num_uses
      inputs = merge_towers_and_uses(inputs)
    # The second data format: Tensors with uses folded into the batch dim.
    else:
      inputs = tuple(inputs)

    # Now perform the analogous processing for grads_list.
    if not grads_list or not grads_list[0]:
      raise ValueError("grads_list must contain data for at least one source "
                       "and one tower.")

    # The first data format: a list (over sources) of lists (over towers) of
    # lists (over uses).
    if isinstance(grads_list[0][0], (list, tuple)):
      num_uses = len(grads_list[0][0])
      # self._num_uses may have just been inferred from the inputs, so this
      # also checks that the inputs and outputs agree with each other.
      if self._num_uses is not None and self._num_uses != num_uses:
        raise ValueError("num_uses argument doesn't match length of outputs, "
                         "or length of outputs is inconsistent with length of "
                         "inputs.")
      if not all(len(grad) == num_uses
                 for grads in grads_list
                 for grad in grads):
        raise ValueError("Length of outputs argument is inconsistent across "
                         "towers.")
      self._num_uses = num_uses
      # Each source is processed independently; the leading (source)
      # dimension is kept intact.
      grads_list = tuple(merge_towers_and_uses(grads) for grads in grads_list)
    # The second data format.
    else:
      grads_list = tuple(tuple(grads) for grads in grads_list)

    if self._num_uses is None:
      raise ValueError("You must supply a value for the num_uses argument if "
                       "the number of uses cannot be inferred from inputs or "
                       "outputs arguments (e.g. if they are both given in the "
                       "single Tensor format, instead of as lists of Tensors).")

    return inputs, grads_list