lingvo/core/recurrent.py
def _StackedRecurrent(devices, cell_fns, cell_grads, cell_outs, cell_out_grads,
thetas, init_states, inputs, accumulator_layers,
unused_acc_state):
"""Implementation of StackedRecurrent, see StackedRecurrent for details."""
num_layers = len(devices)
assert num_layers
def _MakeList(fns):
if not isinstance(fns, (list, tuple)):
return [fns] * num_layers
else:
assert num_layers == len(fns)
return fns
cell_fns = _MakeList(cell_fns)
cell_grads = _MakeList(cell_grads)
cell_outs = _MakeList(cell_outs)
cell_out_grads = _MakeList(cell_out_grads)
accumulator_layers = accumulator_layers or [None] * num_layers
assert num_layers == len(thetas)
assert all(isinstance(x, py_utils.NestedMap) for x in thetas)
assert num_layers == len(init_states)
assert all(isinstance(x, py_utils.NestedMap) for x in init_states)
assert isinstance(inputs, py_utils.NestedMap)
if py_utils.use_tpu():
# If this error happens, the number of splits must be increased (e.g.
# worker_split_size in trainer/tpu.sh), or the number of rnn layers
# decreased.
# TODO(cwhipkey): lift this restriction by grouping layers by device and
# having a device handle a contiguous run of layers, and have them loop
# over the layers in the cell fns.
assert len(devices) == len(set(devices)), (
'StackedRecurrent must provide a different device for each layer '
'when run on TPU. devices passed were: %s' % str(devices))
if num_layers == 1:
# Simple case, just use Recurrent() directly.
with tf.device(devices[0]):
acc_states, final = Recurrent(
theta=thetas[0],
state0=init_states[0],
inputs=inputs,
cell_fn=cell_fns[0],
cell_grad=cell_grads[0],
accumulator_layer=accumulator_layers[0])
# Just the accumulated states.
return cell_outs[0](acc_states), final
# We add explicit data dependencies between layer-i's theta/state0 and
# layer-(i-1)'s theta/state0; layer-0's theta/state0 has an explicit data
# dependency on the inputs. These extra data dependencies ensure that if
# layer-i's theta/state0 is used in tf.gradient, the backprop of all the
# layers above it is triggered as well.
prev = [inputs]
for i in range(num_layers):
with tf.device(devices[i]):
thetas[i], init_states[i] = _DependsOn([thetas[i], init_states[i]], prev)
prev = [thetas[i], init_states[i]]
def ExpectedOutputOfLayers():
"""Estimate the dtypes and shapes of the tensors output by each layer."""
def ZerosLikeRequireShape(t):
assert t.shape.is_fully_defined()
return tf.zeros_like(t)
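# TPU compilation requires statically known tensor shapes, so on TPU we
# insist that every cell output has a fully defined shape before creating
# its zero placeholder.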
if py_utils.use_tpu():
transform_fn = ZerosLikeRequireShape
else:
transform_fn = tf.zeros_like
expected_output_by_layers = []
xs = _Index(inputs, 0)
for i in range(num_layers):
# Disable accumulators and step_seed since this is not a real call to
# cell_fns[i]. They will be re-enabled in _Recurrent.
if accumulator_layers[i]:
accumulator_layers[i].accumulators.Transform(lambda x: x.Disable())
step_seed = py_utils.GetStepSeed()
state1, extras = cell_fns[i](thetas[i], init_states[i], xs)
py_utils.ResetStepSeed(step_seed)
# Only the dtypes and shapes are needed.
xs = cell_outs[i](state1)
expected_output_by_layers += [
py_utils.NestedMap(
xs=xs.Transform(transform_fn),
extras=extras.Transform(transform_fn))
]
return expected_output_by_layers
expected_output_by_layers = ExpectedOutputOfLayers()
# Sequence length. We assume the computation forms a grid (time steps x layers).
slen_dim = _SeqLenDim(inputs)
assert num_layers >= 2
layers = []
padding = FlattenPadding(inputs.get('padding', None))
# Builds the input layer.
out_links = _CreateLinks(expected_output_by_layers[0].xs,
DevicePair(devices[0], devices[1]))
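# out_links connects layer 0 (on devices[0]) to layer 1 (on devices[1]);
# layer 0's per-step outputs reach the next layer through these links.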
# Enable accumulators. Note that this must happen prior to the initial
# _AugmentState() calls below, or the augmented states will be initialized
# with default accumulator values.
for accumulator_layer in accumulator_layers:
if accumulator_layer:
accumulator_layer.accumulators.Transform(lambda x: x.Enable())
inp_l = _Input(
cell_fn=cell_fns[0],
cell_grad=cell_grads[0],
cell_out=cell_outs[0],
cell_out_grad=cell_out_grads[0],
theta=thetas[0],
state0=_AugmentState(init_states[0].DeepCopy(), accumulator_layers[0]),
accumulator_layer=accumulator_layers[0],
inputs=inputs,
extras=expected_output_by_layers[0].extras,
out_links=out_links,
unused_acc_state=unused_acc_state)
layers += [inp_l]
# Builds the intermediate layers.
for i in range(1, num_layers - 1):
in_links = out_links
out_links = _CreateLinks(expected_output_by_layers[i].xs,
DevicePair(devices[i], devices[i + 1]))
mid_l = _Middle(
cell_fn=cell_fns[i],
cell_grad=cell_grads[i],
cell_out=cell_outs[i],
cell_out_grad=cell_out_grads[i],
theta=thetas[i],
state0=_AugmentState(init_states[i].DeepCopy(), accumulator_layers[i]),
accumulator_layer=accumulator_layers[i],
in_links=in_links,
padding=padding,
slen_dim=slen_dim,
per_step_inputs=expected_output_by_layers[i - 1].xs,
extras=expected_output_by_layers[i].extras,
out_links=out_links,
unused_acc_state=unused_acc_state)
layers += [mid_l]
# Builds the final output layer.
in_links = out_links
del out_links
out_l = _Output(
cell_fn=cell_fns[-1],
cell_grad=cell_grads[-1],
theta=thetas[-1],
state0=_AugmentState(init_states[-1].DeepCopy(), accumulator_layers[-1]),
accumulator_layer=accumulator_layers[-1],
in_links=in_links,
padding=padding,
slen_dim=slen_dim,
per_step_inputs=expected_output_by_layers[-2].xs,
extras=expected_output_by_layers[-1].extras)
layers += [out_l]
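# At this point layers[0] consumes the real inputs, layers[1:-1] stream
# per-step outputs between adjacent devices via the links created above,
# and layers[-1] produces the accumulated states used for the final output.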
assert len(layers) == num_layers
anchor = 0
final_states = []
for (dev, layer) in zip(devices, layers):
# Compute each layer on its designated device.
with tf.device(dev):
acc_states, final = layer.Compute()  # The final state is collected just below.
final_states.append(final)
# We sum every number output by the layer into a scalar s and compute a
# zero scalar (s - s) as an anchor. Anchors are accumulated sequentially
# and added to the final layer's output. This way, we ensure that the
# final output depends on every previous layer through data dependencies.
# This is a hack to ensure that tf.gradient will follow some data
# dependency path to start the backward loop for each layer.
#
# TODO(zhifengc): We can write, if we have nil & first ops:
# anchor += [nil(py_utils.Flatten(acc_states))]
# And finally,
# return acc_states.Transform(lambda x: first(x, anchor))
def ComputeAnchor(x):
# Sum every element of every tensor in x into a single scalar s, then
# return s - s: a zero that still carries data dependencies on all of x.
s = tf.add_n([tf.reduce_sum(_) for _ in x.Flatten()])
return s - s
anchor = ComputeAnchor(acc_states) + anchor
# The last layer's output is the real output that matters. However,
# to make backprop work for the previous layers, we need to make sure
# the returned value has data dependencies on the previous layers.
# 'anchor' is guaranteed to be a scalar 0 and hence adding it to the
# final output does not change its numerical value.
with tf.device(devices[-1]):
outputs = cell_outs[-1](acc_states.Transform(lambda x: x + anchor))
# TODO(b/129159299): The ResetStepSeed below is needed to work around this
# bug, which is a problem with global tensors being shared by different
# inference graphs. It should be removed once the bug is fixed.
py_utils.MaybeResetStepSeedFromScope()
return outputs, final_states
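# Illustrative usage sketch (not part of the library; names like my_cell_fn
# are hypothetical). The public StackedRecurrent wrapper is expected to call
# this with one device per layer and per-layer thetas/init_states, e.g.:
#
#   outputs, final_states = _StackedRecurrent(
#       devices=['/device:GPU:0', '/device:GPU:1'],  # one device per layer
#       cell_fns=my_cell_fn,            # shared by all layers, or a list
#       cell_grads=my_cell_grad,
#       cell_outs=my_cell_out,
#       cell_out_grads=my_cell_out_grad,
#       thetas=[theta0, theta1],        # one py_utils.NestedMap per layer
#       init_states=[state0_a, state0_b],
#       inputs=input_nested_map,        # may include an optional 'padding'
#       accumulator_layers=None,
#       unused_acc_state=False)
#
# `outputs` is the last layer's accumulated states passed through
# cell_outs[-1]; `final_states` holds each layer's final recurrent state.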