def _BodyFPropInternal()

in lingvo/core/gshard_layers.py [0:0]


  def _BodyFPropInternal(self,
                         theta,
                         fn_name,
                         iteration,
                         num_microbatches,
                         *args,
                         kwargs_no_batch=None,
                         **kwargs):
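    """Runs the stage body's `fn_name` for one pipelining iteration.

    Args (as inferred from their use below):
      theta: NestedMap of stage weights, stacked along a leading stage dim.
      fn_name: name of the method to call on the stage body layer.
      iteration: current iteration of the pipelining loop.
      num_microbatches: total number of microbatches.
      *args: per-stage positional inputs with a leading stage dimension.
      kwargs_no_batch: keyword inputs without a leading microbatch dimension;
        passed through unsliced.
      **kwargs: keyword inputs with a leading microbatch dimension; each stage
        gets the slice for its current microbatch.

    Returns:
      A tuple of (outputs, control_out, context_tensors).
    """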
    p = self.params
    wrappers = []

    # Wrap non-trainable vars with VarWrapperTrackAssign to track control
    # dependencies.
    def _WrapWithTracking(v):
      if v.trainable:
        return v
      wrapper = var_tmp_wrappers.VarWrapperTrackAssign(v)
      wrappers.append(wrapper)
      return wrapper

    def _BodyFProp(x):
      with self.TransformVarsTempContext(_WrapWithTracking):
        # Create an inner aux loss context, and extract the aux losses as extra
        # outputs so that the function can be vectorized.
        with py_utils.AuxLossContext(reentrant=True) as al_ctx:
          with py_utils.TpuSummaryTensorContext():
            if p.per_stage_vars:
              outs = getattr(self.body_iter_00000, fn_name)(x.theta, *x.args,
                                                            **x.kwargs)
            else:
              outs = getattr(self.body, fn_name)(x.theta, *x.args, **x.kwargs)
            context_tensors = py_utils.NestedMap(
                tpu_summary_tensors=py_utils.GetTpuSummaryTensors(),
                aux_losses=al_ctx.aux_losses)
            if not wrappers:
              return outs, tf.zeros([], dtype=tf.int32), context_tensors
            with tf.control_dependencies(
                [w.control_after_assigns() for w in wrappers]):
              control_out = tf.zeros([], dtype=tf.int32)
            return outs, control_out, context_tensors

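    # If a stage-parallel body is configured, it consumes the full stacked
    # inputs directly: merge kwargs_no_batch into kwargs and call the body
    # once, without the per-stage slicing/sharding transforms below.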
    if p.stage_parallel_body is not None:
      for key, val in (kwargs_no_batch or {}).items():
        kwargs[key] = val
      return _BodyFProp(
          py_utils.NestedMap(theta=theta, args=args, kwargs=kwargs))

    theta_args = py_utils.NestedMap(theta=theta, args=args)

    if p.shard_stages_1d:
      device_mesh = np.arange(p.num_stages)
      stage_mesh_dim = 0
    elif p.pipeline_stage_mesh_dim is not None:
      device_mesh = p.device_mesh
      stage_mesh_dim = p.pipeline_stage_mesh_dim
    else:
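      # No stage mesh configured; use the tf.vectorized_map fallback below.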
      device_mesh = None

    if device_mesh is not None:
      # Each stage should have its own seed.
      seeds = tf.stack([py_utils.GetIncStepSeed() for _ in range(p.num_stages)])
      seeds = gshard_utils.Replicate(seeds)

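      # _ToManual converts a stacked [num_stages, ...] tensor to its per-stage
      # view under manual SPMD partitioning: dim 0 is split across
      # stage_mesh_dim and then squeezed away. When `var` is given, the
      # variable's existing op sharding is reused instead.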
      def _ToManual(x, var=None):
        if not isinstance(x, (tf.Operation, tf.Tensor, tf.Variable)):
          return x
        if var is None:
          sharding = gshard_utils.GetMeshSplitSharding(
              device_mesh, [stage_mesh_dim] + [-1] *
              (len(x.shape) - 1)).proto.SerializeToString()
          # Partially specify that only dim 0 is annotated with sharding.
          unspecified_dims = list(range(1, len(x.shape)))
        else:
          sharding = xla_sharding.get_op_sharding(var.op)
          unspecified_dims = None
        to_manual = xla_sharding.auto_to_manual_spmd_partition(
            x, sharding, single_dim=0, unspecified_dims=unspecified_dims)
        return tf.squeeze(to_manual, 0)

      if p.per_stage_vars:
        manual_theta = tf.nest.map_structure(_ToManual, theta_args.theta)
      else:
        manual_theta = tf.nest.map_structure(_ToManual, theta_args.theta,
                                             self.body.vars)
      one_stage_theta_args = py_utils.NestedMap(
          theta=manual_theta,
          args=tf.nest.map_structure(_ToManual, theta_args.args))
      py_utils.ResetStepSeed(_ToManual(seeds))

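      # _ToManualReplicate converts a tensor that is shared by all stages to
      # the manual-partitioned view, where each stage sees the full value.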
      def _ToManualReplicate(x):
        if not isinstance(x, (tf.Operation, tf.Tensor)):
          return x
        if p.shard_stages_1d:
          sharding = xla_sharding.Sharding.replicate()
          return xla_sharding.auto_to_manual_spmd_partition(
              x, sharding.proto.SerializeToString())
        else:
          # We do a broadcast first, then we can reuse _ToManual().
          x = tf.broadcast_to(x, [p.num_stages] + x.shape)
          return _ToManual(x)

      stage_id = _ToManual(tf.range(p.num_stages))

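      # Per-stage microbatch index and (for circular schedules) repeat index
      # for this iteration.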
      microbatch_ids, repeat_ids = self._MicrobatchAndRepeatIDs(iteration)
      microbatch_id = _ToManual(microbatch_ids)
      repeat_id = _ToManual(repeat_ids)

      if p.circular_repeat > 1:
        one_stage_theta_args.theta = tf.nest.map_structure(
            lambda x: x[repeat_id], one_stage_theta_args.theta)

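      # Clamp to a valid microbatch index for iterations past the last
      # microbatch (pipeline bubbles).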
      microbatch_id = tf.minimum(microbatch_id, num_microbatches - 1)

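      # Slice the per-microbatch kwargs down to this stage's current
      # microbatch (in the manual-partitioned view).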
      def _KwargSlice(x):
        if not isinstance(x, (tf.Operation, tf.Tensor)):
          return x
        return _ToManualReplicate(x)[microbatch_id]

      one_stage_theta_args.kwargs = tf.nest.map_structure(_KwargSlice, kwargs)
      for key, val in (kwargs_no_batch or {}).items():
        one_stage_theta_args.kwargs[key] = tf.nest.map_structure(
            _ToManualReplicate, val)

      # Wrap non-trainable vars with StackedVarWrapperWithManualSharding, in
      # case they are accessed directly in FProp (e.g., batch norm vars).
      def _WrapWithManual(v):
        if v.trainable:
          return v
        return var_tmp_wrappers.StackedVarWrapperWithManualSharding(v)

      with self.TransformVarsTempContext(_WrapWithManual):
        # Step seed should be incremented by p.num_stages.
        with py_utils.StepSeedIncrementContext(p.num_stages):
          with py_utils.GlobalStepContext(
              _ToManualReplicate(py_utils.GetGlobalStep())):
            # If there are any internal annotations in the stage, they will be
            # subgrouped with manual partitioning on stage_mesh_dim.
            with gshard_utils.ManualMeshDimContext(stage_mesh_dim):
              one_stage_outputs, control_out, context_tensors = _BodyFProp(
                  one_stage_theta_args)

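      # _ToAuto is the inverse of _ToManual: it restores the leading stage dim
      # and converts from manual partitioning back to an auto-sharded
      # [num_stages, ...] tensor.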
      def _ToAuto(x):
        if not isinstance(x, (tf.Operation, tf.Tensor)):
          return x
        full_shape = [p.num_stages] + x.shape
        unspecified_dims = list(range(1, len(full_shape)))
        sharding = gshard_utils.GetMeshSplitSharding(
            device_mesh, [stage_mesh_dim] + [-1] * len(x.shape))
        x = tf.expand_dims(x, 0)
        return xla_sharding.manual_to_auto_spmd_partition(
            x,
            sharding.proto.SerializeToString(),
            full_shape=full_shape,
            single_dim=0,
            unspecified_dims=unspecified_dims)

      # Reset step seed to the last stage's final seed.
      py_utils.ResetStepSeed(_ToAuto(py_utils.GetStepSeed())[-1])
      # Convert outputs and context tensors (TPU summaries, per-stage aux
      # losses) back to stacked per-stage tensors.
      outputs = tf.nest.map_structure(_ToAuto, one_stage_outputs)
      context_tensors = tf.nest.map_structure(_ToAuto, context_tensors)
      return outputs, control_out, context_tensors

    else:
      stage_id = tf.range(p.num_stages)
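      # Stage s processes microbatch (iteration - s); clamp at 0 during the
      # pipeline warm-up iterations.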
      microbatch_id = tf.maximum(iteration - stage_id,
                                 tf.zeros([p.num_stages], dtype=stage_id.dtype))

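      # Gather each stage's current microbatch from the per-microbatch kwargs.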
      def _KwargSlice(x):
        if not isinstance(x, (tf.Operation, tf.Tensor)):
          return x
        return tf.gather(x, microbatch_id)

      theta_args.kwargs = tf.nest.map_structure(_KwargSlice, kwargs)
      for key, val in (kwargs_no_batch or {}).items():
        theta_args.kwargs[key] = val
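      # Vectorize the single-stage body over the leading stage dimension.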
      return tf.vectorized_map(
          _BodyFProp, theta_args, fallback_to_while_loop=False)
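
To make the final tf.vectorized_map fallback concrete, here is a minimal, self-contained sketch. The shapes are toy values, and stage_fn, stacked_w, and microbatches are hypothetical stand-ins (not Lingvo APIs); it only illustrates how per-stage weights and per-microbatch inputs are paired up and vectorized over the leading stage dimension.

import tensorflow as tf

num_stages, num_microbatches, model_dim = 3, 5, 4

def stage_fn(one_stage):
  # one_stage['w']: [model_dim, model_dim] weights of a single stage.
  # one_stage['x']: [model_dim] input for that stage's current microbatch.
  return tf.tanh(tf.linalg.matvec(one_stage['w'], one_stage['x']))

stacked_w = tf.random.normal([num_stages, model_dim, model_dim])
microbatches = tf.random.normal([num_microbatches, model_dim])

iteration = tf.constant(2)
stage_id = tf.range(num_stages)
# Stage s works on microbatch (iteration - s); clamp at 0 during warm-up,
# mirroring the non-sharded branch above.
microbatch_id = tf.maximum(iteration - stage_id, 0)
per_stage_x = tf.gather(microbatches, microbatch_id)  # [num_stages, model_dim]

# Vectorize the single-stage function over the leading stage dimension.
outs = tf.vectorized_map(
    stage_fn, {'w': stacked_w, 'x': per_stage_x}, fallback_to_while_loop=False)
print(outs.shape)  # (3, 4)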