def FPropFn()

in lingvo/core/gshard_layers.py [0:0]


  def FPropFn(self,
              theta,
              fn_name,
              *args,
              padded_per_stage_states,
              kwargs_no_batch=None,
              **kwargs):
    """Runs forward pass on a specified function.

    Runs the body function named `fn_name` (through `self.BodyFProp`) as a
    pipeline over `p.num_stages` stages. The loop is built with
    `recurrent.Recurrent` over the padded microbatch iterations; the first
    `p.num_stages - 1` iterations of the output are pipeline bubbles and are
    dropped before returning. If `p.unroll == 'always'`, or
    `p.unroll == 'eval_only'` while `self.do_eval` is set, the call is
    delegated to `self._unrolled_fprop` instead.

    Args:
      theta: NestedMap of layer weights. When `p.per_stage_vars` is set, stage
        i's weights are read from `theta['body_iter_%05d' % i]` and stacked
        along a new leading dim; otherwise `theta.body` is used.
      fn_name: Name of the body function to run; forwarded to
        `self.BodyFProp`.
      *args: Inputs to the body. Tensor leaves are split into microbatches
        along dim 0 when microbatching is needed (see below), fed to stage 0,
        and the last stage's outputs for them are returned.
      padded_per_stage_states: Per-stage states passed to the body after the
        selected inputs and fed to Recurrent as per-iteration inputs. Their
        dim 0 is the iteration dim (it is stripped to get the per-iteration
        shape); presumably dim 1 is the stage dim — TODO confirm with callers.
      kwargs_no_batch: Optional keyword args forwarded as-is to
        `self.BodyFProp`; they are not microbatched.
      **kwargs: Keyword args forwarded to `self.BodyFProp`. Tensor leaves are
        microbatched like `args` when microbatching is needed.

    Returns:
      When unrolled, whatever `self._unrolled_fprop` returns. Otherwise the
      last-stage outputs corresponding to `args` (converted back to batch
      layout if microbatching was applied), followed by the per-stage states
      accumulated by Recurrent (`accum.per_stage_states`). A single output is
      returned unwrapped; multiple outputs are returned as a tuple.

    Raises:
      ValueError: If a NestedMap input contains a non-tensor value.
    """
    p = self.params

    # Unrolled (non-Recurrent) path.
    # NOTE(review): fn_name, padded_per_stage_states and kwargs_no_batch are
    # not forwarded on this path — confirm _unrolled_fprop does not need them.
    if p.unroll == 'always' or (self.do_eval and p.unroll == 'eval_only'):
      return self._unrolled_fprop(theta, *args, **kwargs)

    if p.per_stage_vars:
      # Stack the weights of all stages so every leaf gets a leading
      # num_stages dim.
      all_iters = [theta['body_iter_%05d' % i] for i in range(p.num_stages)]
      theta_body = tf.nest.map_structure(lambda *t: tf.stack(list(t)),
                                         *all_iters)

      def _PassthroughVarSharding(x):
        """Annotates stacked var `x` with the sharding of one stage's var."""
        # x.shape[1:] is the shape of a single stage's var. A negative dim from
        # _FindPerStageVarShardingDim means the stacked var is replicated;
        # otherwise it is split num_stages ways along that dim.
        split_dim = self._FindPerStageVarShardingDim(x.shape[1:])
        if split_dim < 0:
          x = xla_sharding.replicate(x, use_sharding_op=True)
        else:
          # The stacked theta has a leading dim, so we use split_dim + 1.
          x = gshard_utils.Split(
              x, split_dim + 1, p.num_stages, use_sharding_op=True)
        return x

      if p.shard_stages_1d:
        # Pass through the per-stage variables' sharding to the stacked theta.
        # Later, this will be resharded on the leading stage dim. This explicit
        # resharding makes sure that resharding happens after the concat, and
        # concat partitioning is trivial on the pass-through dim.
        theta_body = tf.nest.map_structure(_PassthroughVarSharding, theta_body)
    else:

      def _AnnotateTheta(x, var):
        """Copies the sharding annotation of `var` onto theta tensor `x`."""
        return xla_sharding.copy_sharding(var, x, use_sharding_op=True)

      # Shared body weights: mirror each variable's sharding on its theta.
      theta_body = tf.nest.map_structure(_AnnotateTheta, theta.body,
                                         self.vars.body)

    # Determine num_microbatches and whether inputs must be reshaped:
    # - p.num_microbatches set: use it and microbatch the inputs.
    # - else p.microbatch_size set: num_microbatches = batch_size //
    #   microbatch_size (batch_size = dim 0 of the first flattened arg), and
    #   microbatch the inputs.
    # - neither: inputs are used as-is and their dim 0 is taken as
    #   num_microbatches.
    needs_microbatching = False
    if p.num_microbatches is None:
      num_microbatches = py_utils.Flatten(args)[0].get_shape().as_list()[0]
      if p.microbatch_size is not None:
        batch_size = num_microbatches
        assert batch_size % p.microbatch_size == 0
        num_microbatches = batch_size // p.microbatch_size
        needs_microbatching = True
    else:
      num_microbatches = p.num_microbatches
      needs_microbatching = True

    if needs_microbatching:

      def _ToMicrobatches(x):
        """Reshapes [B, ...] to [num_microbatches, B // num_microbatches, ...].

        Microbatch m holds batch rows m, m + num_microbatches, ... (strided,
        not contiguous). Non-tensor values are returned unchanged.
        """
        if not isinstance(x, (tf.Operation, tf.Tensor)):
          return x
        x_shape = py_utils.GetShape(x)
        assert x_shape[0] % num_microbatches == 0
        # We first put num_microbatches in the inner dimension then transpose
        # it. This allows the sharding on the batch (if any) to be propagated
        # to the microbatch dimension. We cannot shard the num_microbatches
        # dimension, since it's indexed by the loop iteration.
        reshaped = tf.reshape(
            x, [x_shape[0] // num_microbatches, num_microbatches] + x_shape[1:])
        return tf.transpose(reshaped,
                            [1, 0] + list(range(2, len(reshaped.shape))))

      args = tf.nest.map_structure(_ToMicrobatches, args)
      kwargs = tf.nest.map_structure(_ToMicrobatches, kwargs)

    def _MaybeReplicateNumMicrobatches(x):
      """Annotates dim 0 of tensor `x` as replicated, depending on params.

      With p.shard_stages_1d the whole tensor is replicated; with
      p.pipeline_stage_mesh_dim set only dim 0 is specified (other dims are
      left unspecified); otherwise `x` (or any non-tensor) is returned as-is.
      """
      # Mark the num_microbatches dim replicated.
      if not isinstance(x, (tf.Operation, tf.Tensor)):
        return x
      if p.shard_stages_1d:
        return gshard_utils.Replicate(x)
      if p.pipeline_stage_mesh_dim is not None:
        # Partially specify that only dim 0 is replicated.
        return gshard_utils.MeshSplit(
            x,
            p.device_mesh, [-1] * len(x.shape),
            unspecified_dims=list(range(1, len(x.shape))))
      return x

    # Replicate the input as the layer is only sharded on the stage dimension.
    args = tf.nest.map_structure(_MaybeReplicateNumMicrobatches, args)
    kwargs = tf.nest.map_structure(_MaybeReplicateNumMicrobatches, kwargs)

    if p.shard_stages_1d:

      def _SplitStages(x):
        """Shards `x` num_stages ways along its leading (stage) dim."""
        return gshard_utils.Split(x, 0, p.num_stages)

      theta_body = tf.nest.map_structure(_SplitStages, theta_body)

    # Adds a `stages` dimension after the leading num_microbatches to the inputs
    # which will be sharded. Also pad the leading num_microbatches dimension by
    # num_stages - 1 to match loop iteration count, which corresponds to the
    # bubbles between forward and backward passes.
    #
    # Inputs are not the loop state: they are not changed during the loop. The
    # state (shifting buffer) does not have a num_microbatches dimension.
    def _PadInput(inp):
      if not isinstance(inp, (tf.Operation, tf.Tensor)):
        return inp
      # Takes input tensor of shape [num_microbatches, ...] and returns padded
      # tensor of shape [num_iterations_with_bubbles, num_stages,  ...],
      # where num_stages is a new dimension.
      with_new_dim = tf.expand_dims(inp, 1)
      padded = self._PadMicrobatchesInternal(with_new_dim, pad_stages=True)
      assert len(padded.shape) == len(inp.shape) + 1
      assert padded.shape[1] == p.num_stages
      return padded

    # Static shapes of the padded inputs ([num_iterations, num_stages, ...])
    # and of one loop iteration's slice (leading iteration dim removed).
    padded_inputs = tf.nest.map_structure(_PadInput, args)
    padded_shapes = tf.nest.map_structure(
        lambda x: None if x is None else x.shape, padded_inputs)
    remove_first_dim = lambda x: None if x is None else x[1:]
    state_shapes = tf.nest.map_structure(remove_first_dim, padded_shapes)

    def _ArgsToState(arg_list):
      """Returns a NestedMap from a list of FProp args.

      Arg `idx` is stored under key '_s{idx}'; None args are omitted. Only the
      first len(padded_inputs) entries of `arg_list` are read.

      Raises:
        ValueError: If a NestedMap arg contains a non-tensor value.
      """
      state = py_utils.NestedMap()
      # Maintains a mapping from arg_idx to tensor. states cannot contain None
      # tensors.
      for idx in range(len(padded_inputs)):
        if isinstance(arg_list[idx], py_utils.NestedMap):
          # Make sure each value in the NestedMap is a tensor.
          if not all(isinstance(t, tf.Tensor) for t in arg_list[idx].Flatten()):
            raise ValueError(
                'Each value in the input NestedMap must be a tensor.')
        if arg_list[idx] is not None:
          state['_s{}'.format(idx)] = arg_list[idx]
      return state

    def _StateToArgs(state, shapes):
      """Returns a list of FProp args from a NestedMap.

      Inverse of _ArgsToState: a missing '_s{idx}' key becomes None. Also sets
      the static shape of each returned tensor from `shapes[idx]`.
      """
      arg_list = []
      for idx in range(len(padded_inputs)):
        attr = '_s{}'.format(idx)
        arg_list.append(state[attr] if attr in state else None)
        # NOTE(review): for a None entry this calls None.set_shape, so None
        # args appear unsupported here despite _ArgsToState skipping them —
        # confirm.
        tf.nest.map_structure(lambda x, s: x.set_shape(s), arg_list[-1],
                              shapes[idx])
      return arg_list

    # Set by _CellFn (only when there are no non-trainable vars) to the nested
    # structure of the body's TPU summary tensors; used after Recurrent to
    # unflatten the accumulated summaries.
    self._tpu_summary_structure = None

    def _CellFn(theta, state0, inputs_and_per_stage_states):
      """Recurrent cell function wrapper of body.FProp.

      Runs one pipeline iteration: builds per-stage inputs (leading num_stages
      dim) from the shifted previous state and this iteration's input, calls
      BodyFProp once on them, and stores the outputs as the next state.

      Args:
        theta: Body weights passed through by Recurrent (`theta_body`).
        state0: Loop state NestedMap with `args`, `per_stage_states`,
          `iteration`, `aux_loss` and `non_trainable_vars`.
        inputs_and_per_stage_states: This iteration's slice of the Recurrent
          inputs, with fields `inputs` (padded args) and `per_stage_states`.

      Returns:
        (state1, extras): the next loop state, and a NestedMap of values for
        Recurrent to accumulate (flattened TPU summary tensors, or empty when
        there are non-trainable vars).
      """
      inputs = inputs_and_per_stage_states.inputs
      per_stage_states = inputs_and_per_stage_states.per_stage_states
      tf.nest.map_structure(lambda x, y: x.set_shape(y.shape[1:]),
                            per_stage_states, padded_per_stage_states)
      state0.iteration.set_shape([])
      state0.aux_loss.set_shape([])

      def _SelectInput(state, inp):
        """Returns this iteration's per-stage input.

        Stage s > 0 receives stage s - 1's output from `state` (shifted by one
        along dim 0). Stage 0 receives `inp` (cast to the state dtype), except
        in a circular pipeline outside the first num_stages iterations of each
        circular_repeat * num_stages segment, where it receives the last
        stage's output instead (buffer rotation).
        """
        in_mask = tf.equal(tf.range(p.num_stages), 0)
        if p.circular_repeat == 1:
          # The state is aligned to previous stage. We shift it to the right by
          # 1 stage. If the stage dimension is partitioned in GShard, this will
          # cause a collective-permute being added.
          padding = [[1, 0]] + [[0, 0]] * (len(state.shape) - 1)
          shifted_state = tf.pad(state, padding)[0:p.num_stages, ...]
        else:
          # Rotate the circular buffer. If the stage dimension is partitioned in
          # GShard, this will cause a collective-permute being added.
          shifted_state = tf.concat([state[-1:], state[:-1]], axis=0)
          in_segment_offset = tf.math.mod(state0.iteration,
                                          p.circular_repeat * p.num_stages)
          in_mask = tf.logical_and(in_mask,
                                   tf.less(in_segment_offset, p.num_stages))

        in_mask = tf.reshape(in_mask,
                             [p.num_stages] + [1] * (len(inp.shape) - 1))
        return tf.where(
            tf.broadcast_to(in_mask, shifted_state.shape),
            tf.cast(inp, shifted_state.dtype), shifted_state)

      selected_inputs = tf.nest.map_structure(
          _SelectInput, _StateToArgs(state0.args, state_shapes),
          _StateToArgs(inputs, state_shapes))

      # Restore non-trainable vars to state0, because it can be called in the
      # backward pass. state0.non_trainable_vars holds the values as of the
      # start of this iteration.
      assigns = tf.nest.map_structure(lambda v, s: v.assign(s),
                                      self._non_trainable_vars,
                                      state0.non_trainable_vars)

      def _BodyFPropWithAuxLoss():
        """Runs BodyFProp, capturing the aux losses and TPU summaries it adds.

        Returns:
          (fprop_outputs, ctrl, context_tensors), where context_tensors has
          fields `tpu_summary_tensors` and `aux_losses`.
        """
        with py_utils.AuxLossContext(reentrant=True) as al_ctx:
          with py_utils.TpuSummaryTensorContext():
            fprop_outputs, ctrl = self.BodyFProp(
                theta,
                fn_name,
                state0.iteration,
                num_microbatches,
                *selected_inputs,
                *per_stage_states,
                kwargs_no_batch=kwargs_no_batch,
                **kwargs)
            context_tensors = py_utils.NestedMap(
                tpu_summary_tensors=py_utils.GetTpuSummaryTensors(),
                aux_losses=al_ctx.aux_losses)
        return fprop_outputs, ctrl, context_tensors

      if assigns:
        # Group the dependencies into a single no_op to avoid quadratic number
        # of control edges.
        with tf.control_dependencies(assigns):
          ctrl_before = tf.no_op()
        with tf.control_dependencies([ctrl_before]):
          fprop_outputs, ctrl, context_tensors = _BodyFPropWithAuxLoss()
      else:
        fprop_outputs, ctrl, context_tensors = _BodyFPropWithAuxLoss()
      # Body outputs: one per selected input, then one per per-stage state.
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(selected_inputs) + len(per_stage_states)

      # Passes fprop outputs to the next layer through state.
      state1 = py_utils.NestedMap(
          args=_ArgsToState(fprop_outputs[:len(selected_inputs)]),
          per_stage_states=list(fprop_outputs[len(selected_inputs):]),
          iteration=state0.iteration + tf.constant(1, dtype=tf.int32))

      # v and v0 are the new and old values for each stage with leading dim
      # num_stages. Selects v if it's a valid iteration and v0 if it's a bubble.
      def _NewValueIfValidIter(v, v0):
        """Per stage, returns `v` on a valid iteration and `v0` on a bubble.

        Stage s is valid when the microbatch id from _MicrobatchAndRepeatIDs
        is < num_microbatches and iteration >= s (stage s only receives real
        data from iteration s onwards). `v` must have rank >= 1 with a leading
        dim broadcastable against num_stages. `v` is read only after `ctrl`.
        """
        mb_id, _ = self._MicrobatchAndRepeatIDs(state0.iteration)
        valid_iter = tf.logical_and(
            tf.less(mb_id, num_microbatches),
            tf.greater_equal(state0.iteration, tf.range(p.num_stages)))
        with tf.control_dependencies([ctrl]):
          v1 = tf.identity(v)
        return tf.where(
            tf.broadcast_to(
                tf.reshape(valid_iter,
                           [p.num_stages] + [1] * (len(v.shape) - 1)), v.shape),
            v1, v0)

      # Pass state0.non_trainable_vars or updated values depending on whether
      # it is a bubble iteration.
      state1.non_trainable_vars = tf.nest.map_structure(
          _NewValueIfValidIter, self._non_trainable_vars,
          state0.non_trainable_vars)

      # Sum the body's aux losses, zero out bubble stages and accumulate the
      # total into the loop state. NOTE(review): _NewValueIfValidIter assumes
      # the summed aux loss has a leading num_stages dim — confirm.
      if context_tensors.aux_losses:
        context_tensors.aux_losses = tf.add_n([
            tf.cast(l, state0.aux_loss.dtype)
            for l in context_tensors.aux_losses
        ])
        state1.aux_loss = state0.aux_loss + tf.reduce_sum(
            _NewValueIfValidIter(context_tensors.aux_losses,
                                 tf.zeros_like(context_tensors.aux_losses)))
      else:
        state1.aux_loss = state0.aux_loss
      if self._non_trainable_vars:
        # Skip summary tensors when there are non-trainable vars. Recurrent()
        # uses reflection to figure out the signature, but that causes problems
        # for stateful computation.
        extras = py_utils.NestedMap()
      else:
        self._tpu_summary_structure = tf.nest.map_structure(
            lambda _: None, context_tensors.tpu_summary_tensors)
        # Set the value/weight of the summary tensors to 0 for bubble
        # iterations.
        context_tensors.tpu_summary_tensors = tf.nest.map_structure(
            lambda x: _NewValueIfValidIter(x, tf.zeros_like(x)),
            context_tensors.tpu_summary_tensors)
        extras = py_utils.NestedMap(
            tpu_summary_tensors=tf.nest.flatten(
                context_tensors.tpu_summary_tensors))
      return state1, extras

    with tf.name_scope(p.name):
      # Recurrent per-iteration inputs: the padded args keyed '_s{idx}'.
      inputs_nmap = _ArgsToState(padded_inputs)

      def _CreateInitState(inp):
        """Returns zeros shaped like one iteration's slice of `inp`."""
        return tf.zeros(py_utils.GetShape(inp)[1:], dtype=inp.dtype)

      # Add FProp arg list to state0. Non-trainable vars are snapshotted with
      # tf.identity so the loop starts from their current values.
      state0 = py_utils.NestedMap(
          args=tf.nest.map_structure(_CreateInitState, inputs_nmap),
          per_stage_states=tf.nest.map_structure(_CreateInitState,
                                                 padded_per_stage_states),
          iteration=tf.constant(0, dtype=tf.int32),
          aux_loss=tf.constant(0, dtype=tf.float32),
          non_trainable_vars=tf.nest.map_structure(tf.identity,
                                                   self._non_trainable_vars))
      # Assigned after Recurrent returns (below); _RestoreVarsToFinal reads it
      # through the closure at call time, not at definition time.
      final_non_trainable_var_values = None

      def _RestoreVarsToFinal():
        """Assigns the final loop values back to the non-trainable vars.

        Returns:
          A one-element list with a no_op that depends on the assignments.
        """
        assert final_non_trainable_var_values is not None
        assigns = tf.nest.map_structure(lambda v, x: v.assign(x),
                                        self._non_trainable_vars,
                                        final_non_trainable_var_values)
        with tf.control_dependencies(assigns):
          return [tf.no_op()]

      # Runs body.FProp k times using Recurrent where k = dim 0 of inputs_nmap.
      # backward_cleanup presumably runs after the backward pass, where _CellFn
      # re-assigns the vars to per-iteration values — confirm in Recurrent.
      accum, outputs, accum_extras = recurrent.Recurrent(
          theta=theta_body,
          state0=state0,
          inputs=py_utils.NestedMap(
              inputs=inputs_nmap, per_stage_states=padded_per_stage_states),
          cell_fn=_CellFn,
          # Use {} to avoid reflection call that affects non trainable vars.
          extras={} if self._non_trainable_vars else None,
          allow_implicit_capture=p.allow_implicit_capture,
          allowed_tensor_captures=self._non_trainable_vars + [
              x for x in py_utils.Flatten([kwargs, kwargs_no_batch])
              if isinstance(x, (tf.Operation, tf.Tensor))
          ],
          backward_cleanup=(_RestoreVarsToFinal
                            if self._non_trainable_vars else None),
          return_acc_extras=True)

      # Retrieves fprop outputs.
      def _ExtractLastStage(outp):
        """Returns the last stage's outputs with bubble iterations removed.

        Drops the first num_stages - 1 iterations and takes stage -1. For a
        circular pipeline the remaining iterations are viewed as
        [num_segments, circular_repeat, num_stages, ...]; only the last repeat
        of each segment is kept, then flattened and trimmed to
        num_microbatches.
        """
        if p.circular_repeat == 1:
          return outp[p.num_stages - 1:, -1, ...]
        else:
          # See the class documentation for circular pipeline.
          bubble_removed = outp[p.num_stages - 1:, -1, ...]
          num_segments = (num_microbatches + p.num_stages - 1) // p.num_stages
          segmented = tf.reshape(
              bubble_removed,
              [num_segments, p.circular_repeat, p.num_stages] + outp.shape[2:])
          return tf.reshape(segmented[:, -1,
                                      ...], [num_segments * p.num_stages] +
                            outp.shape[2:])[:num_microbatches]

      final_non_trainable_var_values = outputs.non_trainable_vars
      output_tensors = tf.nest.map_structure(
          _ExtractLastStage, _StateToArgs(accum.args, padded_shapes))
      output_per_stage_states = accum.per_stage_states
      if self._non_trainable_vars:
        # Make the returned tensors depend on restoring the vars' final values.
        with tf.control_dependencies(_RestoreVarsToFinal()):
          output_tensors = tf.nest.map_structure(tf.identity, output_tensors)
          output_per_stage_states = tf.nest.map_structure(
              tf.identity, output_per_stage_states)

      # Report the accumulated aux loss to the enclosing context, if any. With
      # 'mean' it is divided by num_microbatches; otherwise added undivided.
      aux_loss_context = py_utils.AuxLossContext.Current()
      if aux_loss_context:
        if p.aux_loss_microbatch_accumulation == 'mean':
          outputs.aux_loss = tf.div(outputs.aux_loss, num_microbatches)
        aux_loss_context.AddLoss(outputs.aux_loss)

      # Re-emit each (value, weight) summary as a per-stage weighted average
      # over iterations; bubble iterations were zeroed in _CellFn.
      if self._tpu_summary_structure is not None:
        tpu_summary_tensors = tf.nest.pack_sequence_as(
            self._tpu_summary_structure, accum_extras.tpu_summary_tensors)
        for key, (value, weight) in tpu_summary_tensors.items():
          for stage_id in range(p.num_stages):
            v, w = py_utils.WeightedAvg(value[:, stage_id, ...],
                                        weight[:, stage_id, ...])
            py_utils.AddTpuSummaryTensor('%s/stage_%s' % (key, stage_id), v, w)

      output_tensors = tf.nest.map_structure(_MaybeReplicateNumMicrobatches,
                                             output_tensors)
      if needs_microbatching:

        def _ToBatches(x):
          """Inverse of _ToMicrobatches: [M, B // M, ...] -> [B, ...]."""
          x_shape = py_utils.GetShape(x)
          transposed = tf.transpose(x, [1, 0] + list(range(2, len(x_shape))))
          return tf.reshape(transposed,
                            [num_microbatches * x_shape[1]] + x_shape[2:])

        output_tensors = tf.nest.map_structure(_ToBatches, output_tensors)
      # Append the per-stage states after the arg outputs; unwrap a single
      # output.
      output_tensors += output_per_stage_states
      return output_tensors[0] if len(output_tensors) == 1 else tuple(
          output_tensors)