def _StackedRecurrent()

in lingvo/core/recurrent.py [0:0]


def _StackedRecurrent(devices, cell_fns, cell_grads, cell_outs, cell_out_grads,
                      thetas, init_states, inputs, accumulator_layers,
                      unused_acc_state):
  """Implementation of StackedRecurrent, see StackedRecurrent for details.

  Builds one recurrent layer per entry in `devices` and chains them: layer 0
  consumes `inputs`, and every following layer consumes the per-step outputs
  of the layer below it through links created by `_CreateLinks`.

  Args:
    devices: A list of devices, one per layer. Its length defines the number
      of layers. On TPU every layer must be given a distinct device.
    cell_fns: A single cell function shared by all layers, or a list/tuple
      with exactly one cell function per layer.
    cell_grads: Same broadcasting rules as `cell_fns`; per-layer cell gradient
      functions.
    cell_outs: Same broadcasting rules as `cell_fns`; per-layer functions
      applied to a layer's accumulated states to produce its output.
    cell_out_grads: Same broadcasting rules as `cell_fns`; per-layer
      counterparts of `cell_outs` handed to the input/middle layers.
    thetas: A list/tuple of `py_utils.NestedMap`, one per layer. The caller's
      sequence is not modified.
    init_states: A list/tuple of `py_utils.NestedMap` initial states, one per
      layer. The caller's sequence is not modified.
    inputs: A `py_utils.NestedMap` fed to the first layer. An optional
      'padding' entry is flattened and forwarded to the upper layers.
    accumulator_layers: None, or a list/tuple with one entry per layer. Each
      entry is None or a layer whose `accumulators` are disabled during the
      shape-probing pass and re-enabled before the layers are built.
    unused_acc_state: Forwarded as-is to the input and middle layers.

  Returns:
    A tuple (outputs, final_states). `outputs` is `cell_outs[-1]` applied to
    the last layer's accumulated states. With a single layer, `final_states`
    is the final state returned by `Recurrent` (not wrapped in a list);
    otherwise it is a list holding one final state per layer.
  """
  num_layers = len(devices)
  assert num_layers

  def _MakeList(fns):
    """Broadcasts a single callable to all layers, or validates a list."""
    if not isinstance(fns, (list, tuple)):
      return [fns] * num_layers
    else:
      assert num_layers == len(fns)
      return fns

  cell_fns = _MakeList(cell_fns)
  cell_grads = _MakeList(cell_grads)
  cell_outs = _MakeList(cell_outs)
  cell_out_grads = _MakeList(cell_out_grads)
  accumulator_layers = accumulator_layers or [None] * num_layers
  # A mismatched length would otherwise surface as an IndexError deep in the
  # graph construction, or silently pair accumulators with the wrong layer.
  assert num_layers == len(accumulator_layers), (
      'Expected %d accumulator_layers, got %d' %
      (num_layers, len(accumulator_layers)))
  assert num_layers == len(thetas)
  assert all(isinstance(x, py_utils.NestedMap) for x in thetas)
  assert num_layers == len(init_states)
  assert all(isinstance(x, py_utils.NestedMap) for x in init_states)
  assert isinstance(inputs, py_utils.NestedMap)

  # Work on private copies: the dependency-injection loop below rebinds the
  # per-layer entries, which would otherwise clobber the caller's lists (or
  # raise TypeError when tuples are passed).
  thetas = list(thetas)
  init_states = list(init_states)

  if py_utils.use_tpu():
    # If this error happens, the number of splits must be increased (e.g.
    # worker_split_size in trainer/tpu.sh), or the number of rnn layers
    # decreased.
    # TODO(cwhipkey): lift this restriction by grouping layers by device and
    # having a device handle a contiguous run of layers, and have them loop
    # over the layers in the cell fns.
    assert len(devices) == len(set(devices)), (
        'StackedRecurrent must provide a different device for each layer '
        'when run on TPU. devices passed were: %s' % str(devices))

  if num_layers == 1:
    # Simple case, just use Recurrent() directly.
    # NOTE(review): unused_acc_state is not forwarded on this path; confirm
    # that Recurrent's default is what callers expect.
    with tf.device(devices[0]):
      acc_states, final = Recurrent(
          theta=thetas[0],
          state0=init_states[0],
          inputs=inputs,
          cell_fn=cell_fns[0],
          cell_grad=cell_grads[0],
          accumulator_layer=accumulator_layers[0])
      # Just the accumulated states.
      return cell_outs[0](acc_states), final

  # We add explicit data dependencies between layer-i's theta/state0
  # and layer-(i-1)'s theta/state0, layer-0's theta/state0 has an
  # explicit data dependency on inputs.  These extra data dependencies
  # ensure that if layer-i's theta/state0 is used in tf.gradient, all
  # layers above's backprop are triggered.
  prev = [inputs]
  for i in range(num_layers):
    with tf.device(devices[i]):
      thetas[i], init_states[i] = _DependsOn([thetas[i], init_states[i]], prev)
    prev = [thetas[i], init_states[i]]

  def ExpectedOutputOfLayers():
    """Estimate what tensor dtypes and shapes output by each layer."""

    def ZerosLikeRequireShape(t):
      assert t.shape.is_fully_defined()
      return tf.zeros_like(t)

    if py_utils.use_tpu():
      transform_fn = ZerosLikeRequireShape
    else:
      transform_fn = tf.zeros_like

    expected_output_by_layers = []
    xs = _Index(inputs, 0)
    for i in range(num_layers):
      # This is a dry-run call to cell_fns[i] used only to probe output
      # dtypes/shapes, so accumulators are disabled for it (they are
      # re-enabled below, before the layers are built) and the step seed is
      # saved and restored around the call so it is left unchanged.
      if accumulator_layers[i]:
        accumulator_layers[i].accumulators.Transform(lambda x: x.Disable())
      step_seed = py_utils.GetStepSeed()
      state1, extras = cell_fns[i](thetas[i], init_states[i], xs)
      py_utils.ResetStepSeed(step_seed)
      # only dtype and shape is needed.
      xs = cell_outs[i](state1)
      expected_output_by_layers += [
          py_utils.NestedMap(
              xs=xs.Transform(transform_fn),
              extras=extras.Transform(transform_fn))
      ]
    return expected_output_by_layers

  expected_output_by_layers = ExpectedOutputOfLayers()

  # Sequence length. We assume it's a grid we are building.
  slen_dim = _SeqLenDim(inputs)

  assert num_layers >= 2
  layers = []

  padding = FlattenPadding(inputs.get('padding', None))

  # Builds the input layer.
  out_links = _CreateLinks(expected_output_by_layers[0].xs,
                           DevicePair(devices[0], devices[1]))

  # Enable accumulators. Note that this must happen prior to the initial
  # _AugmentState() below or it will initialize with defaults.
  for accumulator_layer in accumulator_layers:
    if accumulator_layer:
      accumulator_layer.accumulators.Transform(lambda x: x.Enable())

  inp_l = _Input(
      cell_fn=cell_fns[0],
      cell_grad=cell_grads[0],
      cell_out=cell_outs[0],
      cell_out_grad=cell_out_grads[0],
      theta=thetas[0],
      state0=_AugmentState(init_states[0].DeepCopy(), accumulator_layers[0]),
      accumulator_layer=accumulator_layers[0],
      inputs=inputs,
      extras=expected_output_by_layers[0].extras,
      out_links=out_links,
      unused_acc_state=unused_acc_state)
  layers += [inp_l]

  # Builds the intermediate layers.
  for i in range(1, num_layers - 1):
    in_links = out_links
    out_links = _CreateLinks(expected_output_by_layers[i].xs,
                             DevicePair(devices[i], devices[i + 1]))
    mid_l = _Middle(
        cell_fn=cell_fns[i],
        cell_grad=cell_grads[i],
        cell_out=cell_outs[i],
        cell_out_grad=cell_out_grads[i],
        theta=thetas[i],
        state0=_AugmentState(init_states[i].DeepCopy(), accumulator_layers[i]),
        accumulator_layer=accumulator_layers[i],
        in_links=in_links,
        padding=padding,
        slen_dim=slen_dim,
        per_step_inputs=expected_output_by_layers[i - 1].xs,
        extras=expected_output_by_layers[i].extras,
        out_links=out_links,
        unused_acc_state=unused_acc_state)
    layers += [mid_l]

  # Builds the final output layer.
  in_links = out_links
  del out_links
  out_l = _Output(
      cell_fn=cell_fns[-1],
      cell_grad=cell_grads[-1],
      theta=thetas[-1],
      state0=_AugmentState(init_states[-1].DeepCopy(), accumulator_layers[-1]),
      accumulator_layer=accumulator_layers[-1],
      in_links=in_links,
      padding=padding,
      slen_dim=slen_dim,
      per_step_inputs=expected_output_by_layers[-2].xs,
      extras=expected_output_by_layers[-1].extras)
  layers += [out_l]

  assert len(layers) == num_layers

  def _ComputeAnchor(x):
    """Returns a scalar zero that data-depends on every tensor in `x`."""
    s = tf.add_n([tf.reduce_sum(t) for t in x.Flatten()])
    return s - s

  anchor = 0
  final_states = []
  for (dev, layer) in zip(devices, layers):
    # Computes each layer on their designated device.
    with tf.device(dev):
      acc_states, final = layer.Compute()
      final_states.append(final)

      # We add every number output by the layer (s) and computes a
      # zero scalar: (s - s), as an anchor. Anchors are added
      # sequentially and added to the final layer's output. This way,
      # we ensure that the final output depends on every previous
      # layer through data dependencies. This is a hack to ensure that
      # tf.gradient will follow some data dependencies path to start
      # the Backward loop for each layer.
      #
      # TODO(zhifengc): We can write, if we have nil & first ops:
      #   anchor += [nil(py_utils.Flatten(acc_states))]
      # And finally,
      #   return acc_states.Transform(lambda x: first(x, anchor))
      anchor = _ComputeAnchor(acc_states) + anchor

  # The last layer's output is the real output that matters.  However,
  # to make the previous layers backprop work, we need to make sure
  # the returned value has data dependencies on the previous layers.
  # 'anchor' is guaranteed to be a scalar 0 and hence adding it to the
  # final output does not change its numerical value.
  # `acc_states` here is the value left by the last loop iteration, i.e. the
  # output layer's accumulated states.
  with tf.device(devices[-1]):
    outputs = cell_outs[-1](acc_states.Transform(lambda x: x + anchor))

  # TODO(b/129159299): The ResetStepSeed below is needed to work around this
  # bug, which is a problem with global tensors being shared by different
  # inference graphs. It should be removed once the bug is fixed.
  py_utils.MaybeResetStepSeedFromScope()

  return outputs, final_states