def FProp()

in lingvo/core/gpipe.py [0:0]


  def FProp(self, theta, *args):
    """Runs the cells as a micro-batched pipeline spread over devices.

    When `self.do_eval` is set and the cluster has one device per split, the
    layers in `self._before_layers` and `self._cells` are applied one after
    another without any pipelining. Otherwise the outputs of the before-layers
    are split into `p.num_micro_batches` pieces along `p.batch_dim`, pushed
    through `recurrent.StackedRecurrent` (one cell fn and one device entry per
    cell), and the stacked results are merged back along `p.batch_dim`.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      In the single-device eval path, whatever the last layer's FProp returns.
      Otherwise, if `p.nested_map_fprop` is set, the structure obtained by
      mapping over the last entry of the computed state shapes; if it is not
      set, a single tensor when the last cell has exactly one output and a
      tuple of tensors otherwise. Outputs whose shape is None are None.
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if self.do_eval and self.cluster.num_devices_per_split == 1:
      # No pipelining needed: chain every layer on the current device.
      activations = copy.copy(args)
      for layer_name, layer in self._before_layers + self._cells:
        layer_inputs = _ToTuple(activations)
        activations = layer.FProp(theta[layer_name], *layer_inputs)
      return activations

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_shapes = self._get_input_shapes(*args)
    state_dtype = self._get_state_dtype(*args)
    state_shapes = self._CalculateOutputShapes(input_shapes)
    tf.logging.info('state_shapes={}'.format(state_shapes))

    def _IsPresent(x):
      """Predicate used to drop None leaves from a NestedMap."""
      return x is not None

    def _PackTuple(values):
      """Puts non-None `values` into a NestedMap keyed 's0', 's1', ..."""
      packed = py_utils.NestedMap()
      for idx, value in enumerate(values):
        if value is not None:
          packed['s{}'.format(idx)] = value
      return packed

    def _ZeroState(t_shape):
      """Returns zeros of `t_shape` in the state dtype, or None for no shape."""
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

    def GetCellFn(i):
      """Builds the cell fn that wraps the ith cell."""

      def CellFn(theta, state0, inputs):
        """Applies cell `i` to `inputs`; passed to StackedRecurrent."""
        del state0

        def _ShapedInput(name, t_shape):
          # Attach the statically computed shape to the incoming tensor.
          if t_shape is None:
            return None
          tensor = inputs[name]
          tensor.set_shape(t_shape.ToTensorShape().as_list())
          return tensor

        if p.nested_map_fprop:
          # pylint: disable=protected-access
          fprop_inputs = state_shapes[i]._RecursiveMap(_ShapedInput)
          # pylint: enable=protected-access
        else:
          fprop_inputs = [
              _ShapedInput('s{}'.format(idx), shape)
              for idx, shape in enumerate(state_shapes[i])
          ]

        with py_utils.RemoveAssertContext(remove=True):
          with CellFnFPropOpReplacementWrapper():
            tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
            mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
            SetOverWriteGlobalStep(mb_tensor)
            _, cell = self._cells[i]
            fprop_inputs = _ToTuple(fprop_inputs)
            outputs = cell.FProp(theta, *fprop_inputs)

        if p.nested_map_fprop:
          assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
          state1 = outputs.Filter(_IsPresent)
        else:
          outputs = _ToTuple(outputs)
          assert len(outputs) == len(state_shapes[i + 1])
          state1 = _PackTuple(outputs)
        # Forward the micro-batch id unchanged to the next cell.
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        return state1, py_utils.NestedMap()

      return CellFn

    # Per-cell arguments for StackedRecurrent, all indexed by cell position.
    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])

      # The initial state of a cell is all zeros, shaped like its output.
      if p.nested_map_fprop:
        init_state = py_utils.Transform(_ZeroState, state_shapes[cell_idx + 1])
        init_state = init_state.Filter(_IsPresent)
      else:
        init_state = _PackTuple(
            _ZeroState(shape) for shape in state_shapes[cell_idx + 1])
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      init_states.append(init_state)

      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      pre_outputs = _ToTuple(args)
      for layer_name, layer in self._before_layers:
        pre_outputs = _ToTuple(layer.FProp(theta[layer_name], *pre_outputs))

      def _ToMicroBatches(x):
        # Split tensors into microbatches.
        if x is None:
          return None
        pieces = tf.split(x, p.num_micro_batches, axis=p.batch_dim)
        return tf.stack(pieces)

      if p.nested_map_fprop:
        inputs = py_utils.Transform(_ToMicroBatches, pre_outputs[0])
        inputs = inputs.Filter(_IsPresent)
      else:
        inputs = _PackTuple(_ToMicroBatches(t) for t in pre_outputs)
      gs_tensor = py_utils.GetGlobalStep()
      # One id per micro-batch, derived from the global step.
      micro_batch_ids = [
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ]
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack(micro_batch_ids)
    tf.logging.info('pipeline input = {}'.format(inputs))
    output_state, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):

      def _MergeMicroBatches(name, t_shape):
        """Restore shape for tensors in microbatches."""
        if t_shape is None:
          return None
        tensor = output_state[name]
        if p.batch_dim != 0:
          # Move the leading (stacked) axis right in front of the batch axis.
          leading = list(range(1, p.batch_dim + 1))
          trailing = list(range(p.batch_dim + 1, t_shape.rank + 1))
          tensor = tf.transpose(tensor, perm=leading + [0] + trailing)
        full_shape = t_shape.ToTensorShape().as_list()
        full_shape[p.batch_dim] *= p.num_micro_batches
        return tf.reshape(tensor, full_shape)

      # Construct the final return values from output_state.
      if p.nested_map_fprop:
        # pylint: disable=protected-access
        output_tensors = state_shapes[-1]._RecursiveMap(_MergeMicroBatches)
        # pylint: enable=protected-access
      else:
        results = [
            _MergeMicroBatches('s{}'.format(idx), shape)
            for idx, shape in enumerate(state_shapes[-1])
        ]
        # A lone output is returned bare; several are returned as a tuple.
        output_tensors = results[0] if len(results) == 1 else tuple(results)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      return output_tensors