in lingvo/core/gshard_layers.py [0:0]
def _BodyFPropInternal(self,
                       theta,
                       fn_name,
                       iteration,
                       num_microbatches,
                       *args,
                       kwargs_no_batch=None,
                       **kwargs):
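  """Runs the pipeline body function `fn_name` for one loop iteration.

  Args:
    theta: A NestedMap of weights for the pipeline body.
    fn_name: Name of the body method to call via getattr.
    iteration: Scalar tensor, the current pipeline loop iteration.
    num_microbatches: Total number of microbatches.
    *args: Positional inputs forwarded to the body.
    kwargs_no_batch: Optional dict of keyword inputs shared by all
      microbatches; these are not sliced by microbatch id.
    **kwargs: Keyword inputs with a leading num_microbatches dimension,
      sliced to the current microbatch of each stage.

  Returns:
    (outputs, control_out, context_tensors), where control_out carries
    control dependencies on non-trainable variable assignments and
    context_tensors holds aux losses and TPU summary tensors.
  """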
  p = self.params
  wrappers = []

  # Wrap non-trainable vars with VarWrapperTrackAssign to track control
  # dependencies.
  def _WrapWithTracking(v):
    if v.trainable:
      return v
    wrapper = var_tmp_wrappers.VarWrapperTrackAssign(v)
    wrappers.append(wrapper)
    return wrapper
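  # Runs one stage's body under the tracking wrapper above. Returns the body
  # outputs, a scalar control token that depends on any non-trainable variable
  # assigns made inside the body, and the captured aux losses / TPU summary
  # tensors as explicit outputs.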
  def _BodyFProp(x):
    with self.TransformVarsTempContext(_WrapWithTracking):
      # Create an inner aux loss context, and extract the aux losses as extra
      # outputs so that the function can be vectorized.
      with py_utils.AuxLossContext(reentrant=True) as al_ctx:
        with py_utils.TpuSummaryTensorContext():
          if p.per_stage_vars:
            outs = getattr(self.body_iter_00000, fn_name)(
                x.theta, *x.args, **x.kwargs)
          else:
            outs = getattr(self.body, fn_name)(x.theta, *x.args, **x.kwargs)
          context_tensors = py_utils.NestedMap(
              tpu_summary_tensors=py_utils.GetTpuSummaryTensors(),
              aux_losses=al_ctx.aux_losses)
          if not wrappers:
            return outs, tf.zeros([], dtype=tf.int32), context_tensors
          with tf.control_dependencies(
              [w.control_after_assigns() for w in wrappers]):
            control_out = tf.zeros([], dtype=tf.int32)
          return outs, control_out, context_tensors
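  # With a stage-parallel body, the body is expected to handle the stage
  # dimension itself, so skip the per-stage transformations below: merge
  # kwargs_no_batch into kwargs and call the body directly.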
  if p.stage_parallel_body is not None:
    for key, val in (kwargs_no_batch or {}).items():
      kwargs[key] = val
    return _BodyFProp(
        py_utils.NestedMap(theta=theta, args=args, kwargs=kwargs))
  theta_args = py_utils.NestedMap(theta=theta, args=args)
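  # Pick how stages are laid out: a 1D device mesh over stages, a
  # caller-specified mesh dimension, or no mesh at all (in which case the body
  # is vectorized over stages with tf.vectorized_map at the end).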
  if p.shard_stages_1d:
    device_mesh = np.arange(p.num_stages)
    stage_mesh_dim = 0
  elif p.pipeline_stage_mesh_dim is not None:
    device_mesh = p.device_mesh
    stage_mesh_dim = p.pipeline_stage_mesh_dim
  else:
    device_mesh = None
  if device_mesh is not None:
    # Each stage should have its own seed.
    seeds = tf.stack([py_utils.GetIncStepSeed() for _ in range(p.num_stages)])
    seeds = gshard_utils.Replicate(seeds)
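    # Converts a tensor with a leading [num_stages] dimension from automatic
    # to manual SPMD partitioning along the stage mesh dimension, then strips
    # that dimension so each stage keeps only its local slice.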
    def _ToManual(x, var=None):
      if not isinstance(x, (tf.Operation, tf.Tensor, tf.Variable)):
        return x
      if var is None:
        sharding = gshard_utils.GetMeshSplitSharding(
            device_mesh, [stage_mesh_dim] + [-1] *
            (len(x.shape) - 1)).proto.SerializeToString()
        # Partially specify that only dim 0 is annotated with sharding.
        unspecified_dims = list(range(1, len(x.shape)))
      else:
        sharding = xla_sharding.get_op_sharding(var.op)
        unspecified_dims = None
      to_manual = xla_sharding.auto_to_manual_spmd_partition(
          x, sharding, single_dim=0, unspecified_dims=unspecified_dims)
      return tf.squeeze(to_manual, 0)
    if p.per_stage_vars:
      manual_theta = tf.nest.map_structure(_ToManual, theta_args.theta)
    else:
      manual_theta = tf.nest.map_structure(
          _ToManual, theta_args.theta, self.body.vars)
    one_stage_theta_args = py_utils.NestedMap(
        theta=manual_theta,
        args=tf.nest.map_structure(_ToManual, theta_args.args))
    py_utils.ResetStepSeed(_ToManual(seeds))
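    # Converts a tensor that has no leading stage dimension (i.e., shared by
    # all stages) into the manual-partitioning domain, either by marking it
    # replicated or by broadcasting a stage dimension and reusing _ToManual().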
    def _ToManualReplicate(x):
      if not isinstance(x, (tf.Operation, tf.Tensor)):
        return x
      if p.shard_stages_1d:
        sharding = xla_sharding.Sharding.replicate()
        return xla_sharding.auto_to_manual_spmd_partition(
            x, sharding.proto.SerializeToString())
      else:
        # We do a broadcast first, then we can reuse _ToManual().
        x = tf.broadcast_to(x, [p.num_stages] + x.shape)
        return _ToManual(x)
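    # Per-stage ids for this loop iteration: which microbatch each stage is
    # working on and, for circular schedules, which repeat of its layers.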
    stage_id = _ToManual(tf.range(p.num_stages))
    microbatch_ids, repeat_ids = self._MicrobatchAndRepeatIDs(iteration)
    microbatch_id = _ToManual(microbatch_ids)
    repeat_id = _ToManual(repeat_ids)
    if p.circular_repeat > 1:
      one_stage_theta_args.theta = tf.nest.map_structure(
          lambda x: x[repeat_id], one_stage_theta_args.theta)
    microbatch_id = tf.minimum(microbatch_id, num_microbatches - 1)
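    # kwargs carry per-microbatch inputs with a leading num_microbatches
    # dimension; slice out the current microbatch of each stage.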
    def _KwargSlice(x):
      if not isinstance(x, (tf.Operation, tf.Tensor)):
        return x
      return _ToManualReplicate(x)[microbatch_id]

    one_stage_theta_args.kwargs = tf.nest.map_structure(_KwargSlice, kwargs)
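    # kwargs_no_batch entries are shared by all microbatches: replicate them
    # into the manual domain without slicing.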
    for key, val in (kwargs_no_batch or {}).items():
      one_stage_theta_args.kwargs[key] = tf.nest.map_structure(
          _ToManualReplicate, val)
    # Wrap non-trainable vars with StackedVarWrapperWithManualSharding, in
    # case they are accessed directly in FProp (e.g., batch norm vars).
    def _WrapWithManual(v):
      if v.trainable:
        return v
      return var_tmp_wrappers.StackedVarWrapperWithManualSharding(v)

    with self.TransformVarsTempContext(_WrapWithManual):
      # Step seed should be incremented by p.num_stages.
      with py_utils.StepSeedIncrementContext(p.num_stages):
        with py_utils.GlobalStepContext(
            _ToManualReplicate(py_utils.GetGlobalStep())):
          # If there are any internal annotations in the stage, they will be
          # subgrouped with manual partitioning on stage_mesh_dim.
          with gshard_utils.ManualMeshDimContext(stage_mesh_dim):
            one_stage_outputs, control_out, context_tensors = _BodyFProp(
                one_stage_theta_args)
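    # Inverse of _ToManual: add back a leading stage dimension and convert
    # from manual back to automatic SPMD partitioning, sharded over the stage
    # mesh dimension.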
    def _ToAuto(x):
      if not isinstance(x, (tf.Operation, tf.Tensor)):
        return x
      full_shape = [p.num_stages] + x.shape
      unspecified_dims = list(range(1, len(full_shape)))
      sharding = gshard_utils.GetMeshSplitSharding(
          device_mesh, [stage_mesh_dim] + [-1] * len(x.shape))
      x = tf.expand_dims(x, 0)
      return xla_sharding.manual_to_auto_spmd_partition(
          x,
          sharding.proto.SerializeToString(),
          full_shape=full_shape,
          single_dim=0,
          unspecified_dims=unspecified_dims)

    # Reset step seed to the last stage's final seed.
    py_utils.ResetStepSeed(_ToAuto(py_utils.GetStepSeed())[-1])
    # Convert aux losses to per-stage vector losses.
    outputs = tf.nest.map_structure(_ToAuto, one_stage_outputs)
    context_tensors = tf.nest.map_structure(_ToAuto, context_tensors)
    return outputs, control_out, context_tensors
  else:
    stage_id = tf.range(p.num_stages)
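    # Standard pipeline skew: at this loop iteration, stage i works on
    # microbatch (iteration - i), clamped at 0 during the initial bubble.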
    microbatch_id = tf.maximum(
        iteration - stage_id,
        tf.zeros([p.num_stages], dtype=stage_id.dtype))

    def _KwargSlice(x):
      if not isinstance(x, (tf.Operation, tf.Tensor)):
        return x
      return tf.gather(x, microbatch_id)

    theta_args.kwargs = tf.nest.map_structure(_KwargSlice, kwargs)
    for key, val in (kwargs_no_batch or {}).items():
      theta_args.kwargs[key] = val
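    # Vectorize the body over the leading stage dimension of theta_args so all
    # stages run in one pass; the while-loop fallback is disabled so that
    # vectorization failures surface as errors.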
    return tf.vectorized_map(
        _BodyFProp, theta_args, fallback_to_while_loop=False)