def __init__()

in lingvo/core/recurrent.py


  def __init__(self,
               cell_fn,
               cell_grad,
               stop_fn,
               theta,
               state0,
               inputs,
               extras,
               cell_type=None,
               accumulator_layer=None,
               implicit_captures=None,
               unused_acc_state=None,
               backward_cleanup=None):
    """RNN helper class.

    Args:
      cell_fn: A python function which computes:
         state1, extras = cell_fn(theta, state0, inputs[t, :])
      cell_grad: A python function which computes:
         dtheta, dstate0, dinputs[t, :] = cell_grad(
           theta, state0, inputs[t, :], extras, dstate1)
      stop_fn: A python function which computes: should_stop = stop_fn(t, theta,
        state0)
      theta: weights. A `.NestedMap`.
      state0: initial state. A `.NestedMap`.
      inputs: inputs. A `.NestedMap`.
      extras: A `.NestedMap` of Tensors. The 2nd return value of every
        invocation of cell_fn must be a `.NestedMap` whose keys and shapes
        match this 'extras'.
      cell_type: Cell type used in this class.
      accumulator_layer: If provided, then accumulators on this layer will be
        managed such that they carry to the final state in `FProp` and are
        disabled for gradients. Uses the state key `accumulators`.
      implicit_captures: A `.NestedMap` corresponding to implicit captures of
        the cell_fn. If empty/None, implicit captures are either not present or
        disallowed.
      unused_acc_state: If None, we assume every field of acc_state is consumed
        by the following timesteps. If True, none of the acc_state is consumed,
        and each timestep's new state is reduce_sum'ed into a scalar. Note: this
        feature should be used with StackedRecurrent, where the new state is
        sent out to the other devices.
      backward_cleanup: An optional callback function (taking no arguments) to
        be invoked after the backward pass. It returns a list of ops, which will
        run as control dependencies of d(inputs) on the backward path. It can be
        used to clean up side effects during recompute.
    """
    self._theta = theta
    self._state = state0
    self._inputs = inputs
    self._cell_fn = _DecorateCellFn(cell_fn, accumulator_layer)
    self._cell_grad = _DecorateCellGrad(cell_grad, accumulator_layer)
    self._stop_fn = stop_fn
    self._extras = extras
    if cell_type is not None:
      self._cell_type = cell_type
    else:
      self._cell_type = 'UnknownType'
    self._accumulator_layer = accumulator_layer
    self._implicit_captures = implicit_captures
    self._unused_acc_state = unused_acc_state
    self._backward_cleanup = backward_cleanup

    # NOTE: The TF Functions (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
    # Forward and Backward defined below) simply take a list of Tensors and
    # return a list of Tensors. When we pass in a structure (a list of
    # NestedMaps of Tensors), we use Flatten to convert the structure into a
    # list of tensors. Conversely, the following code often uses Pack to
    # rebuild a structure from a list of tensors based on a "template".
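    # For example (illustrative only), given
    #   nm = py_utils.NestedMap(a=t0, b=py_utils.NestedMap(c=t1))
    # then
    #   flat = nm.Flatten()   # -> [t0, t1], a flat list of Tensors.
    #   nm2 = nm.Pack(flat)   # -> a NestedMap with the same structure as nm.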

    compiled = py_utils.use_xla()
    noinline = not compiled

    # state1, extras = cell_fn(theta, state0, inputs)
    def Fwd(theta, state0, inputs):
      py_utils.SetShapes(theta, self._theta)
      state1, extras = self._cell_fn(theta, state0, inputs)
      py_utils.AssertIsCompatible(state1, self._state)
      py_utils.AssertIsCompatible(extras, self._extras)
      return state1, extras

    # Wraps cell_fn in a TF Function as a for-loop's body.
    #
    # The loop state is composed of:
    #  t: The loop variable on the device. Timestep id.
    #  theta: the recurrent net's weights.
    #  state0: the previous recurrent state.
    #  inputs: inputs to the recurrent net. inputs[t, :] are the inputs for
    #    timestep t.
    #  acc_state: Each timestep's computed new state is also stashed into
    #    acc_state.
    #  acc_extras: Each timestep's computed extras is stashed into acc_extras
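    #
    # Roughly (an illustrative sketch of the helper semantics), for each
    # Tensor x in these accumulators:
    #   _Index(inputs, t)            reads  inputs.x[t, ...]
    #   _Update(acc_state, state, t) writes acc_state.x[t, ...] = state.x
    # so acc_state.x has shape [seq_len] + state0.x.shape.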

    def ForwardLoopCond(loop_state):
      """The condition of forward loop."""
      should_continue = loop_state.t < loop_state.limit
      if self._stop_fn:
        should_continue = tf.math.logical_and(
            should_continue,
            tf.reduce_any(
                tf.math.logical_not(
                    self._stop_fn(loop_state.t, loop_state.theta,
                                  loop_state.state0))))
      return should_continue

    def ForwardLoopBody(loop_state):
      """The body of forward loop."""
      t = loop_state.t
      # external input at time step t.
      inputs_t = _Index(loop_state.inputs, t)
      loop_state.state0, extras = Fwd(loop_state.theta, loop_state.state0,
                                      inputs_t)
      # Saves state1 and extras in their accumulators.
      if not self._unused_acc_state:
        loop_state.acc_state = _Update(loop_state.acc_state, loop_state.state0,
                                       t)
      loop_state.acc_extras = _Update(loop_state.acc_extras, extras, t)
      loop_state.t = tf.add(t, 1)
      return loop_state

    # Forward calls ForwardLoopBody n times. Each time computes one
    # time step of the recurrent net.
    def Forward(args):
      """Forward pass of the recurrent net."""
      # The sequence length.
      pad_begin, pad_end = _SeqPaddingLength(args.inputs)
      slen_dim = _SeqLenDim(args.inputs)
      limit = slen_dim - pad_end

      # Creates accumulators for state0 and extras.
      if self._unused_acc_state:
        acc_state = _EmptyWithFixShape([slen_dim], args.state0)
      else:
        acc_state = _EmptyAcc(slen_dim, args.state0)
      acc_extras = _EmptyAcc(slen_dim, args.extras)

      if compiled:
        t = tf.cast(pad_begin, tf.int32)
        limit = tf.cast(limit, tf.int32)
      else:
        t = tf.cast(pad_begin, tf.int64)
        limit = tf.cast(limit, tf.int64)

      with py_utils.RemoveAssertContext(remove=noinline):
        run = py_utils.WhileLoop(
            ForwardLoopCond,
            ForwardLoopBody,
            loop_state=py_utils.NestedMap(
                t=t,
                limit=limit,
                theta=args.theta,
                state0=args.state0,
                inputs=args.inputs,
                acc_state=acc_state,
                acc_extras=acc_extras))
      return py_utils.NestedMap(
          limit=run.t,
          final_state=run.state0,
          acc_state=run.acc_state,
          acc_extras=run.acc_extras)

    # The per-step backward computes:
    #    d_theta, d_state0, d_inputs = cell_grad(
    #        theta, state0, inputs, extras, d_state1)
    # where d_state1 is the backprop-ed gradient for state1, and extras is
    # computed by the forward step to facilitate the backward step.
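    #
    # Schematically (an illustrative summary of the loop below), walking time
    # backwards from the last computed step:
    #   d_state1 at step t = d_acc_state[t] + the d_state0 flowing back from
    #       step t + 1
    #   d_theta and d_captured are accumulated (summed) across all steps
    #   d_inputs[t, :] is written once per step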
    def Bak(theta, state0, inputs, extras, d_state1):
      """Backward step."""
      py_utils.SetShapes(theta, self._theta)
      (dtheta, dstate0, dinputs,
       dcaptures) = self._cell_grad(theta, state0, inputs, extras, d_state1)
      py_utils.AssertIsCompatible(dtheta, self._theta)
      py_utils.AssertIsCompatible(dstate0, self._state)
      py_utils.AssertIsCompatible(dinputs, self._inputs)
      if dcaptures is None:
        # NOTE: Custom gradient fns can return None if they do not support
        # captured tensors. The return value is reserved for the future when
        # that may be supported.
        dcaptures = _EmptyLike(self._implicit_captures)
      py_utils.AssertIsCompatible(dcaptures, self._implicit_captures)

      # Make sure this function didn't capture anything different than the
      # cell_fn when reflected on at the beginning. Must come after the call
      # to cell_grad() which adds to the captured list.
      _AssertSameTensors(py_utils.GetExtraInputs(),
                         self._implicit_captures.Flatten())

      return [dtheta, dstate0, dinputs, dcaptures]

    # Wraps cell_grad gradient function in a TF Function as a
    # for-loop's body for the Backward pass.
    #
    # The loop state is composed of:
    #  t: The loop variable on the device. Timestep id.
    #  state0: the initial state for the entire backward loop.
    #  theta: the recurrent net's weights.
    #  inputs: inputs to the recurrent net. inputs[t, :] are the inputs for
    #    timestep t.
    #  acc_state: Each timestep's computed new state was stashed into
    #    acc_state by the Forward pass.
    #  acc_extras: Each timestep's computed extras was stashed into
    #    acc_extras by the Forward pass.
    #  d_theta: Every timestep's gradient for theta is accumulated (added) into
    #      d_theta.
    #  d_state1: The backprop-ed gradient for the new state computed by
    #      timestep t.
    #  d_inputs: d_inputs[t, :] is populated by the backward time step t.
    #  d_acc_state: The backprop-ed gradient for acc_state.
    #  d_captured: Every timestep's gradient for the implicitly captured
    #      tensors is accumulated (added) into d_captured.
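    #
    # The loop below runs t = ys.limit - 1, ys.limit - 2, ..., down to
    # pad_begin (the condition keeps iterating while t >= limit).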

    def BackwardLoopCond(loop_state):
      """Backward loop condition function."""
      return loop_state.t >= loop_state.limit

    def BackwardLoopBody(loop_state):
      """Backward loop body function."""
      t = loop_state.t
      # The input recurrent state for time step t is previous time step's
      # output, or the original state0 when on time step 0.
      state_from_acc = _Index(loop_state.acc_state,
                              tf.maximum(tf.constant(0, t.dtype), t - 1))
      state0 = tf.cond(
          tf.equal(t, tf.constant(0, t.dtype)),
          true_fn=lambda: loop_state.state0,
          false_fn=lambda: state_from_acc)

      # The external inputs for time step t.
      inputs_t = _Index(loop_state.inputs, t)
      # The extras for time step t.
      extras_t = _Index(loop_state.acc_extras, t)

      d_state1 = _Add(_Index(loop_state.d_acc_state, t), loop_state.d_state1)
      (d_theta_t, loop_state.d_state1, d_inputs_t,
       d_captured_t) = Bak(loop_state.theta, state0, inputs_t, extras_t,
                           d_state1)

      if self._unused_acc_state:
        # XLA IF op requires the same shape for if and else branches.
        loop_state.d_state1 = loop_state.d_state1.Transform(tf.reduce_sum)
      loop_state.d_theta = _Add(loop_state.d_theta, d_theta_t)
      loop_state.d_inputs = _Update(loop_state.d_inputs, d_inputs_t, t)
      loop_state.d_captured = _Add(loop_state.d_captured, d_captured_t)
      loop_state.t = tf.subtract(t, 1)

      # Make sure this function didn't capture anything different than the
      # cell_fn when reflected on at the beginning. Must come after the call
      # to Bak() which adds to the captured list.
      _AssertSameTensors(py_utils.GetExtraInputs(),
                         self._implicit_captures.Flatten())

      return loop_state

    # Backward calls BackwardLoopBody n times. Each time computes the backprop
    # for one time step of the recurrent net.
    def Backward(xs, ys, dys):
      """Backward pass for the recurrent net.

      Args:
        xs: inputs to the forward operation.
        ys: outputs of the forward operation.
        dys: gradients to the outputs of the forward operation.

      Returns:
        Gradients to the inputs of the forward operation.
      """
      # Accumulators for gradients.
      d_theta = _EmptyLike(xs.theta)
      d_inputs = _EmptyLike(xs.inputs)
      d_captured = _EmptyLike(self._implicit_captures)

      # The sequence length.
      pad_begin, _ = _SeqPaddingLength(xs.inputs)
      limit = pad_begin

      if compiled:
        limit = tf.cast(limit, tf.int32)
      else:
        limit = tf.cast(limit, tf.int64)

      state0 = xs.state0
      d_state1 = dys.final_state
      if self._unused_acc_state:
        # XLA While op requires the same shape for the init and carry on
        # values.
        state0 = state0.Transform(tf.reduce_sum)
        d_state1 = d_state1.Transform(tf.reduce_sum)

      with py_utils.RemoveAssertContext(remove=noinline):
        run = py_utils.WhileLoop(
            cond=BackwardLoopCond,
            body=BackwardLoopBody,
            loop_state=py_utils.NestedMap(
                t=ys.limit - 1,
                limit=limit,
                theta=xs.theta,
                state0=state0,
                inputs=xs.inputs,
                acc_state=ys.acc_state,
                acc_extras=ys.acc_extras,
                d_theta=d_theta,
                d_state1=d_state1,
                d_inputs=d_inputs,
                d_acc_state=dys.acc_state,
                d_captured=d_captured))

      d_state0 = run.d_state1
      if self._unused_acc_state:
        # Match the shape of gradient of the init_state.
        d_state0 = self._state.Transform(tf.zeros_like)

      d_inputs = run.d_inputs

      if self._backward_cleanup is not None:
        with tf.control_dependencies(d_inputs.Flatten()):
          control_before = tf.no_op()
        with tf.control_dependencies([control_before]):
          with tf.control_dependencies(self._backward_cleanup()):
            d_inputs = d_inputs.Transform(tf.identity)

      # The `extras` input in the Forward function is actually an output of the
      # function. It was supplied as an input only to create acc_extras with the
      # proper shape, so its gradients should be zero.
      return py_utils.NestedMap(
          d_theta=run.d_theta,
          d_state0=d_state0,
          d_inputs=d_inputs,
          d_extras=_EmptyLike(self._extras)), run.d_captured

    # Forward arguments.
    self._fwd_args = py_utils.NestedMap(
        theta=self._theta,
        state0=self._state,
        inputs=self._inputs,
        extras=self._extras)

    # pylint: disable=protected-access
    device_funcs = tf.get_default_graph()._device_functions_outer_to_inner
    self._caller_device = device_funcs[-1] if device_funcs else None
    # pylint: enable=protected-access

    self._forward = Forward
    self._backward = Backward
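
    # How these two callables are typically consumed (a sketch only; the
    # enclosing custom-gradient plumbing and the module's public Recurrent()
    # wrapper are not shown in this excerpt):
    #
    #   acc_state, final_state = recurrent.Recurrent(
    #       theta=theta, state0=state0, inputs=inputs, cell_fn=MyCellFn)
    #
    # where acc_state stacks every timestep's state along a leading time axis
    # and final_state is the state after the last computed step.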