in lingvo/core/gshard_layers.py [0:0]
def FPropFn(self,
theta,
fn_name,
*args,
padded_per_stage_states,
kwargs_no_batch=None,
**kwargs):
"""Runs forward pass on a specified function."""
p = self.params
if p.unroll == 'always' or (self.do_eval and p.unroll == 'eval_only'):
return self._unrolled_fprop(theta, *args, **kwargs)
if p.per_stage_vars:
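# Stack the per-stage copies of each variable along a new leading stage
# dimension so that the loop body can index them by stage.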
all_iters = [theta['body_iter_%05d' % i] for i in range(p.num_stages)]
theta_body = tf.nest.map_structure(lambda *t: tf.stack(list(t)),
*all_iters)
def _PassthroughVarSharding(x):
split_dim = self._FindPerStageVarShardingDim(x.shape[1:])
if split_dim < 0:
x = xla_sharding.replicate(x, use_sharding_op=True)
else:
# The stacked theta has a leading dim, so we use split_dim + 1.
x = gshard_utils.Split(
x, split_dim + 1, p.num_stages, use_sharding_op=True)
return x
if p.shard_stages_1d:
# Pass through the per-stage variables' sharding to the stacked theta.
# Later, this will be resharded on the leading stage dim. This explicit
# resharding makes sure that resharding happens after the concat, and
# concat partitioning is trivial on the pass-through dim.
theta_body = tf.nest.map_structure(_PassthroughVarSharding, theta_body)
else:
def _AnnotateTheta(x, var):
return xla_sharding.copy_sharding(var, x, use_sharding_op=True)
theta_body = tf.nest.map_structure(_AnnotateTheta, theta.body,
self.vars.body)
needs_microbatching = False
if p.num_microbatches is None:
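# p.num_microbatches is not set: by default, treat the leading input dim as
# the number of microbatches; if p.microbatch_size is set instead, treat the
# leading dim as the batch size and derive num_microbatches from it.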
num_microbatches = py_utils.Flatten(args)[0].get_shape().as_list()[0]
if p.microbatch_size is not None:
batch_size = num_microbatches
assert batch_size % p.microbatch_size == 0
num_microbatches = batch_size // p.microbatch_size
needs_microbatching = True
else:
num_microbatches = p.num_microbatches
needs_microbatching = True
if needs_microbatching:
def _ToMicrobatches(x):
if not isinstance(x, (tf.Operation, tf.Tensor)):
return x
x_shape = py_utils.GetShape(x)
assert x_shape[0] % num_microbatches == 0
# We first put num_microbatches in the inner dimension, then transpose
# it. This allows the sharding on the batch (if any) to be propagated
# to the microbatch dimension. We cannot shard the num_microbatches
# dimension, since it's indexed by the loop iteration.
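# For example, an [8, d] input with num_microbatches=4 is reshaped to
# [2, 4, d] and then transposed to [4, 2, d], so each loop iteration reads
# one microbatch of size 2.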
reshaped = tf.reshape(
x, [x_shape[0] // num_microbatches, num_microbatches] + x_shape[1:])
return tf.transpose(reshaped,
[1, 0] + list(range(2, len(reshaped.shape))))
args = tf.nest.map_structure(_ToMicrobatches, args)
kwargs = tf.nest.map_structure(_ToMicrobatches, kwargs)
def _MaybeReplicateNumMicrobatches(x):
# Mark the num_microbatches dim replicated.
if not isinstance(x, (tf.Operation, tf.Tensor)):
return x
if p.shard_stages_1d:
return gshard_utils.Replicate(x)
if p.pipeline_stage_mesh_dim is not None:
# Partially specify that only dim 0 is replicated.
return gshard_utils.MeshSplit(
x,
p.device_mesh, [-1] * len(x.shape),
unspecified_dims=list(range(1, len(x.shape))))
return x
# Replicate the input as the layer is only sharded on the stage dimension.
args = tf.nest.map_structure(_MaybeReplicateNumMicrobatches, args)
kwargs = tf.nest.map_structure(_MaybeReplicateNumMicrobatches, kwargs)
if p.shard_stages_1d:
def _SplitStages(x):
return gshard_utils.Split(x, 0, p.num_stages)
theta_body = tf.nest.map_structure(_SplitStages, theta_body)
# Add a `stages` dimension after the leading num_microbatches dimension of
# the inputs; this new dimension will be sharded. Also pad the leading
# num_microbatches dimension by num_stages - 1 to match the loop iteration
# count; the extra iterations correspond to the bubbles between the forward
# and backward passes.
#
# Inputs are not the loop state: they are not changed during the loop. The
# state (shifting buffer) does not have a num_microbatches dimension.
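# For circular_repeat == 1, this gives num_microbatches + num_stages - 1
# loop iterations, e.g. 4 microbatches on 3 stages run for 6 iterations.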
def _PadInput(inp):
if not isinstance(inp, (tf.Operation, tf.Tensor)):
return inp
# Takes input tensor of shape [num_microbatches, ...] and returns padded
# tensor of shape [num_iterations_with_bubbles, num_stages, ...],
# where num_stages is a new dimension.
with_new_dim = tf.expand_dims(inp, 1)
padded = self._PadMicrobatchesInternal(with_new_dim, pad_stages=True)
assert len(padded.shape) == len(inp.shape) + 1
assert padded.shape[1] == p.num_stages
return padded
padded_inputs = tf.nest.map_structure(_PadInput, args)
padded_shapes = tf.nest.map_structure(
lambda x: None if x is None else x.shape, padded_inputs)
remove_first_dim = lambda x: None if x is None else x[1:]
state_shapes = tf.nest.map_structure(remove_first_dim, padded_shapes)
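# padded_shapes are [num_iterations_with_bubbles, num_stages, ...]; the
# per-iteration loop state drops the leading iteration dim, so state_shapes
# are [num_stages, ...].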
def _ArgsToState(arg_list):
"""Returns a NestedMap from a list of FProp args."""
state = py_utils.NestedMap()
# Maintains a mapping from arg_idx to tensor. The state cannot contain None
# entries, so None args are skipped here and restored in _StateToArgs.
for idx in range(len(padded_inputs)):
if isinstance(arg_list[idx], py_utils.NestedMap):
# Make sure each value in the NestedMap is a tensor.
if not all(isinstance(t, tf.Tensor) for t in arg_list[idx].Flatten()):
raise ValueError(
'Each value in the input NestedMap must be a tensor.')
if arg_list[idx] is not None:
state['_s{}'.format(idx)] = arg_list[idx]
return state
def _StateToArgs(state, shapes):
"""Returns a list of FProp args from a NestedMap."""
arg_list = []
for idx in range(len(padded_inputs)):
attr = '_s{}'.format(idx)
arg_list.append(state[attr] if attr in state else None)
tf.nest.map_structure(lambda x, s: x.set_shape(s), arg_list[-1],
shapes[idx])
return arg_list
self._tpu_summary_structure = None
def _CellFn(theta, state0, inputs_and_per_stage_states):
"""Recurrent cell function wrapper of body.FProp."""
inputs = inputs_and_per_stage_states.inputs
per_stage_states = inputs_and_per_stage_states.per_stage_states
tf.nest.map_structure(lambda x, y: x.set_shape(y.shape[1:]),
per_stage_states, padded_per_stage_states)
state0.iteration.set_shape([])
state0.aux_loss.set_shape([])
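# _SelectInput implements the pipeline shifting buffer: stage 0 reads a
# fresh microbatch from the inputs, while stage i reads what stage i - 1
# produced in the previous iteration.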
def _SelectInput(state, inp):
in_mask = tf.equal(tf.range(p.num_stages), 0)
if p.circular_repeat == 1:
# The state is aligned to the previous stage. We shift it to the right by
# one stage. If the stage dimension is partitioned in GShard, this will
# cause a collective-permute to be added.
padding = [[1, 0]] + [[0, 0]] * (len(state.shape) - 1)
shifted_state = tf.pad(state, padding)[0:p.num_stages, ...]
else:
# Rotate the circular buffer. If the stage dimension is partitioned in
# GShard, this will cause a collective-permute to be added.
shifted_state = tf.concat([state[-1:], state[:-1]], axis=0)
in_segment_offset = tf.math.mod(state0.iteration,
p.circular_repeat * p.num_stages)
in_mask = tf.logical_and(in_mask,
tf.less(in_segment_offset, p.num_stages))
in_mask = tf.reshape(in_mask,
[p.num_stages] + [1] * (len(inp.shape) - 1))
return tf.where(
tf.broadcast_to(in_mask, shifted_state.shape),
tf.cast(inp, shifted_state.dtype), shifted_state)
selected_inputs = tf.nest.map_structure(
_SelectInput, _StateToArgs(state0.args, state_shapes),
_StateToArgs(inputs, state_shapes))
# Restore non-trainable vars to state0, because this cell function can also
# be invoked during the backward pass.
assigns = tf.nest.map_structure(lambda v, s: v.assign(s),
self._non_trainable_vars,
state0.non_trainable_vars)
def _BodyFPropWithAuxLoss():
with py_utils.AuxLossContext(reentrant=True) as al_ctx:
with py_utils.TpuSummaryTensorContext():
fprop_outputs, ctrl = self.BodyFProp(
theta,
fn_name,
state0.iteration,
num_microbatches,
*selected_inputs,
*per_stage_states,
kwargs_no_batch=kwargs_no_batch,
**kwargs)
context_tensors = py_utils.NestedMap(
tpu_summary_tensors=py_utils.GetTpuSummaryTensors(),
aux_losses=al_ctx.aux_losses)
return fprop_outputs, ctrl, context_tensors
if assigns:
# Group the dependencies into a single no_op to avoid a quadratic number of
# control edges.
with tf.control_dependencies(assigns):
ctrl_before = tf.no_op()
with tf.control_dependencies([ctrl_before]):
fprop_outputs, ctrl, context_tensors = _BodyFPropWithAuxLoss()
else:
fprop_outputs, ctrl, context_tensors = _BodyFPropWithAuxLoss()
fprop_outputs = _ToTuple(fprop_outputs)
assert len(fprop_outputs) == len(selected_inputs) + len(per_stage_states)
# Passes fprop outputs to the next layer through state.
state1 = py_utils.NestedMap(
args=_ArgsToState(fprop_outputs[:len(selected_inputs)]),
per_stage_states=list(fprop_outputs[len(selected_inputs):]),
iteration=state0.iteration + tf.constant(1, dtype=tf.int32))
# v and v0 are the new and old values for each stage, with leading dim
# num_stages. Select v on valid iterations and v0 on bubble iterations.
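# An iteration is valid for stage i when the microbatch id is still in range
# and the pipeline has been filled up to stage i (state0.iteration >= i).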
def _NewValueIfValidIter(v, v0):
mb_id, _ = self._MicrobatchAndRepeatIDs(state0.iteration)
valid_iter = tf.logical_and(
tf.less(mb_id, num_microbatches),
tf.greater_equal(state0.iteration, tf.range(p.num_stages)))
with tf.control_dependencies([ctrl]):
v1 = tf.identity(v)
return tf.where(
tf.broadcast_to(
tf.reshape(valid_iter,
[p.num_stages] + [1] * (len(v.shape) - 1)), v.shape),
v1, v0)
# Pass through state0.non_trainable_vars on bubble iterations, and the
# updated values otherwise.
state1.non_trainable_vars = tf.nest.map_structure(
_NewValueIfValidIter, self._non_trainable_vars,
state0.non_trainable_vars)
if context_tensors.aux_losses:
context_tensors.aux_losses = tf.add_n([
tf.cast(l, state0.aux_loss.dtype)
for l in context_tensors.aux_losses
])
state1.aux_loss = state0.aux_loss + tf.reduce_sum(
_NewValueIfValidIter(context_tensors.aux_losses,
tf.zeros_like(context_tensors.aux_losses)))
else:
state1.aux_loss = state0.aux_loss
if self._non_trainable_vars:
# Skip summary tensors when there are non-trainable vars. Recurrent()
# uses reflection to figure out the signature, but that causes problems
# for stateful computation.
extras = py_utils.NestedMap()
else:
self._tpu_summary_structure = tf.nest.map_structure(
lambda _: None, context_tensors.tpu_summary_tensors)
# Set the value/weight of the summary tensors to 0 for bubble
# iterations.
context_tensors.tpu_summary_tensors = tf.nest.map_structure(
lambda x: _NewValueIfValidIter(x, tf.zeros_like(x)),
context_tensors.tpu_summary_tensors)
extras = py_utils.NestedMap(
tpu_summary_tensors=tf.nest.flatten(
context_tensors.tpu_summary_tensors))
return state1, extras
with tf.name_scope(p.name):
inputs_nmap = _ArgsToState(padded_inputs)
def _CreateInitState(inp):
return tf.zeros(py_utils.GetShape(inp)[1:], dtype=inp.dtype)
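# The initial loop state mirrors one iteration's slice of the inputs,
# filled with zeros.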
# Add FProp arg list to state0.
state0 = py_utils.NestedMap(
args=tf.nest.map_structure(_CreateInitState, inputs_nmap),
per_stage_states=tf.nest.map_structure(_CreateInitState,
padded_per_stage_states),
iteration=tf.constant(0, dtype=tf.int32),
aux_loss=tf.constant(0, dtype=tf.float32),
non_trainable_vars=tf.nest.map_structure(tf.identity,
self._non_trainable_vars))
final_non_trainable_var_values = None
def _RestoreVarsToFinal():
assert final_non_trainable_var_values is not None
assigns = tf.nest.map_structure(lambda v, x: v.assign(x),
self._non_trainable_vars,
final_non_trainable_var_values)
with tf.control_dependencies(assigns):
return [tf.no_op()]
# Runs body.FProp k times using Recurrent where k = dim 0 of inputs_nmap.
accum, outputs, accum_extras = recurrent.Recurrent(
theta=theta_body,
state0=state0,
inputs=py_utils.NestedMap(
inputs=inputs_nmap, per_stage_states=padded_per_stage_states),
cell_fn=_CellFn,
# Use {} to avoid the reflection call, which causes problems with
# non-trainable vars.
extras={} if self._non_trainable_vars else None,
allow_implicit_capture=p.allow_implicit_capture,
allowed_tensor_captures=self._non_trainable_vars + [
x for x in py_utils.Flatten([kwargs, kwargs_no_batch])
if isinstance(x, (tf.Operation, tf.Tensor))
],
backward_cleanup=(_RestoreVarsToFinal
if self._non_trainable_vars else None),
return_acc_extras=True)
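# accum holds the states accumulated over iterations, outputs holds the
# final state after the last iteration, and accum_extras holds the
# accumulated extras (here, the TPU summary tensors).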
# Retrieves fprop outputs.
def _ExtractLastStage(outp):
if p.circular_repeat == 1:
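# Skip the first num_stages - 1 fill (bubble) iterations and take the last
# stage's slot of the shifting buffer.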
return outp[p.num_stages - 1:, -1, ...]
else:
# See the class documentation for the circular pipeline.
bubble_removed = outp[p.num_stages - 1:, -1, ...]
num_segments = (num_microbatches + p.num_stages - 1) // p.num_stages
segmented = tf.reshape(
bubble_removed,
[num_segments, p.circular_repeat, p.num_stages] + outp.shape[2:])
return tf.reshape(segmented[:, -1,
...], [num_segments * p.num_stages] +
outp.shape[2:])[:num_microbatches]
final_non_trainable_var_values = outputs.non_trainable_vars
output_tensors = tf.nest.map_structure(
_ExtractLastStage, _StateToArgs(accum.args, padded_shapes))
output_per_stage_states = accum.per_stage_states
if self._non_trainable_vars:
with tf.control_dependencies(_RestoreVarsToFinal()):
output_tensors = tf.nest.map_structure(tf.identity, output_tensors)
output_per_stage_states = tf.nest.map_structure(
tf.identity, output_per_stage_states)
aux_loss_context = py_utils.AuxLossContext.Current()
if aux_loss_context:
if p.aux_loss_microbatch_accumulation == 'mean':
outputs.aux_loss = tf.div(outputs.aux_loss, num_microbatches)
aux_loss_context.AddLoss(outputs.aux_loss)
if self._tpu_summary_structure is not None:
tpu_summary_tensors = tf.nest.pack_sequence_as(
self._tpu_summary_structure, accum_extras.tpu_summary_tensors)
for key, (value, weight) in tpu_summary_tensors.items():
for stage_id in range(p.num_stages):
v, w = py_utils.WeightedAvg(value[:, stage_id, ...],
weight[:, stage_id, ...])
py_utils.AddTpuSummaryTensor('%s/stage_%s' % (key, stage_id), v, w)
output_tensors = tf.nest.map_structure(_MaybeReplicateNumMicrobatches,
output_tensors)
if needs_microbatching:
def _ToBatches(x):
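# Reverses _ToMicrobatches: move the microbatch dim back to the inner
# position and merge it into the batch dim.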
x_shape = py_utils.GetShape(x)
transposed = tf.transpose(x, [1, 0] + list(range(2, len(x_shape))))
return tf.reshape(transposed,
[num_microbatches * x_shape[1]] + x_shape[2:])
output_tensors = tf.nest.map_structure(_ToBatches, output_tensors)
output_tensors += output_per_stage_states
return output_tensors[0] if len(output_tensors) == 1 else tuple(
output_tensors)