in lingvo/core/recurrent.py [0:0]
def __init__(self,
cell_fn,
cell_grad,
stop_fn,
theta,
state0,
inputs,
extras,
cell_type=None,
accumulator_layer=None,
implicit_captures=None,
unused_acc_state=None,
backward_cleanup=None):
"""RNN helper class.
Args:
cell_fn: A python function which computes:
state1, extras = cell_fn(theta, state0, inputs[t, :])
cell_grad: A python function which computes:
dtheta, dstate0, dinputs[t, :] = cell_grad(
theta, state0, inputs[t, :], extras, dstate1)
stop_fn: A python function which computes: should_stop = stop_fn(t, theta,
state0)
theta: weights. A `.NestedMap`.
state0: initial state. A `.NestedMap`.
inputs: inputs. A `.NestedMap`.
extras: A `.NestedMap` of Tensors. The 2nd return value of every
invocation of cell_fn is a `.NestedMap` whose keys and shapes match
this `extras`.
cell_type: An optional name describing the cell type; defaults to
'UnknownType' when not provided.
accumulator_layer: If provided, then accumulators on this layer will be
managed such that they carry to the final state in `FProp` and are
disabled for gradients. Uses the state key `accumulators`.
implicit_captures: A `.NestedMap` corresponding to implicit captures of
the cell_fn. If empty/None, implicit captures are either not present or
disallowed.
unused_acc_state: If None, we assume every field of acc_state is consumed
by subsequent timesteps. If True, none of acc_state is consumed, and we
reduce_sum each timestep's new state into a scalar. Note: this feature
should be used with StackedRecurrent, where the new state is sent out to
other devices.
backward_cleanup: An optional callback function (no argument) to be
invoked after the backward pass. It returns a list of ops, which will
run as control dependencies of d(inputs) on the backward path. Could be
used to clean up side effects during recompute.
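Example (an illustrative sketch only, not part of this API; it assumes an
empty `extras`, and the field names `w`, `h` and `x` are hypothetical). A
cell computing h_t = w * h_{t-1} + x_t, together with a hand-written
gradient, could look like:
  def cell_fn(theta, state0, inputs_t):
    state1 = py_utils.NestedMap(h=theta.w * state0.h + inputs_t.x)
    return state1, py_utils.NestedMap()  # empty extras
  def cell_grad(theta, state0, inputs_t, unused_extras, dstate1):
    return (py_utils.NestedMap(w=dstate1.h * state0.h),  # dtheta
            py_utils.NestedMap(h=dstate1.h * theta.w),   # dstate0
            py_utils.NestedMap(x=dstate1.h))             # dinputs[t, :]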
"""
self._theta = theta
self._state = state0
self._inputs = inputs
self._cell_fn = _DecorateCellFn(cell_fn, accumulator_layer)
self._cell_grad = _DecorateCellGrad(cell_grad, accumulator_layer)
self._stop_fn = stop_fn
self._extras = extras
if cell_type is not None:
self._cell_type = cell_type
else:
self._cell_type = 'UnknownType'
self._accumulator_layer = accumulator_layer
self._implicit_captures = implicit_captures
self._unused_acc_state = unused_acc_state
self._backward_cleanup = backward_cleanup
# NOTE: Each TF function defined below (Fwd, Bak, ForwardLoopBody,
# BackwardLoopBody, Forward and Backward) simply takes a list of
# Tensors and returns a list of Tensors. When we pass in a
# structure (a list of NestedMaps of Tensors), we use Flatten to
# convert the structure into a list of tensors. Conversely, the
# following code often uses Pack to rebuild a structure from a
# list of tensors based on a "template".
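# For example (a small sketch for orientation only, not executed here):
#   tmpl = py_utils.NestedMap(a=tf.zeros([2]), b=tf.zeros([3]))
#   flat = tmpl.Flatten()   # -> [tensor_a, tensor_b], a plain Python list
#   same = tmpl.Pack(flat)  # -> a NestedMap with the same keys as tmpl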
compiled = py_utils.use_xla()
noinline = not compiled
# state1, extras = cell_fn(theta, state0, inputs)
def Fwd(theta, state0, inputs):
py_utils.SetShapes(theta, self._theta)
state1, extras = self._cell_fn(theta, state0, inputs)
py_utils.AssertIsCompatible(state1, self._state)
py_utils.AssertIsCompatible(extras, self._extras)
return state1, extras
# Wraps cell_fn in a TF Function as a for-loop's body.
#
# The loop state is composed of:
# t: The loop variable on the device. Timestep id.
# theta: the recurrent net's weights.
# state0: the previous recurrent state.
# inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
# acc_state: Each timestep's computed new state is also stashed into
# acc_state.
# acc_extras: Each timestep's computed extras is stashed into acc_extras.
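# For example, if an input field (say, inputs.x; the name is illustrative)
# has shape [seqlen, batch, dim], then _Index(inputs, t).x is its
# [batch, dim] slice for step t, and _Update(acc_state, state1, t) writes
# each tensor of state1 into row t of the corresponding [seqlen, ...]
# accumulator.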
def ForwardLoopCond(loop_state):
"""The condition of forward loop."""
should_continue = loop_state.t < loop_state.limit
if self._stop_fn:
should_continue = tf.math.logical_and(
should_continue,
tf.reduce_any(
tf.math.logical_not(
self._stop_fn(loop_state.t, loop_state.theta,
loop_state.state0))))
return should_continue
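# Note on the stopping condition above: if stop_fn returns a per-example
# bool vector (say, a hypothetical state0.all_done), the loop keeps running
# while reduce_any(logical_not(...)) is true, i.e. until every example in
# the batch has signaled that it is done.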
def ForwardLoopBody(loop_state):
"""The body of forward loop."""
t = loop_state.t
# external input at time step t.
inputs_t = _Index(loop_state.inputs, t)
loop_state.state0, extras = Fwd(loop_state.theta, loop_state.state0,
inputs_t)
# Saves state1 and extras in their accumulators.
if not self._unused_acc_state:
loop_state.acc_state = _Update(loop_state.acc_state, loop_state.state0,
t)
loop_state.acc_extras = _Update(loop_state.acc_extras, extras, t)
loop_state.t = tf.add(t, 1)
return loop_state
# Forward calls ForwardLoopBody n times. Each time computes one
# time step of the recurrent net.
def Forward(args):
"""Forward pass of the recurrent net."""
# The sequence length.
pad_begin, pad_end = _SeqPaddingLength(args.inputs)
slen_dim = _SeqLenDim(args.inputs)
limit = slen_dim - pad_end
# Creates accumulators for state0 and extras.
if self._unused_acc_state:
acc_state = _EmptyWithFixShape([slen_dim], args.state0)
else:
acc_state = _EmptyAcc(slen_dim, args.state0)
acc_extras = _EmptyAcc(slen_dim, args.extras)
if compiled:
t = tf.cast(pad_begin, tf.int32)
limit = tf.cast(limit, tf.int32)
else:
t = tf.cast(pad_begin, tf.int64)
limit = tf.cast(limit, tf.int64)
with py_utils.RemoveAssertContext(remove=noinline):
run = py_utils.WhileLoop(
ForwardLoopCond,
ForwardLoopBody,
loop_state=py_utils.NestedMap(
t=t,
limit=limit,
theta=args.theta,
state0=args.state0,
inputs=args.inputs,
acc_state=acc_state,
acc_extras=acc_extras))
return py_utils.NestedMap(
limit=run.t,
final_state=run.state0,
acc_state=run.acc_state,
acc_extras=run.acc_extras)
# The per-step backward computes:
# d_theta, d_state0, d_inputs = cell_grad(
# theta, state0, inputs, extras, d_state1)
# where d_state1 is the backprop-ed gradient for state1, and
# extras is computed by the forward step to facilitate the
# backward step.
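# As an illustrative sketch (reusing the hypothetical cell from the class
# docstring, where state1.h = theta.w * state0.h + inputs.x), one backward
# step would compute:
#   d_theta.w  = d_state1.h * state0.h
#   d_state0.h = d_state1.h * theta.w
#   d_inputs.x = d_state1.h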
def Bak(theta, state0, inputs, extras, d_state1):
"""Backward step."""
py_utils.SetShapes(theta, self._theta)
(dtheta, dstate0, dinputs,
dcaptures) = self._cell_grad(theta, state0, inputs, extras, d_state1)
py_utils.AssertIsCompatible(dtheta, self._theta)
py_utils.AssertIsCompatible(dstate0, self._state)
py_utils.AssertIsCompatible(dinputs, self._inputs)
if dcaptures is None:
# NOTE: Custom gradient fns can return None if they do not support
# captured tensors. The return value is reserved for the future when
# that may be supported.
dcaptures = _EmptyLike(self._implicit_captures)
py_utils.AssertIsCompatible(dcaptures, self._implicit_captures)
# Make sure this function didn't capture anything different than the
# cell_fn when reflected on at the beginning. Must come after the call
# to cell_grad() which adds to the captured list.
_AssertSameTensors(py_utils.GetExtraInputs(),
self._implicit_captures.Flatten())
return [dtheta, dstate0, dinputs, dcaptures]
# Wraps cell_grad gradient function in a TF Function as a
# for-loop's body for the Backward pass.
#
# The loop state is composed of:
# t: The loop variable on the device. Timestep id.
# state0: the initial state for the entire backward loop.
# theta: the recurrent net's weights.
# inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
# acc_state: Each timestep's computed new state was stashed into
# acc_state by the Forward pass.
# acc_extras: Each timestep's computed extras was stashed into
# acc_extras by the Forward pass.
# d_theta: All timesteps' gradients for theta are accumulated (added) into
# d_theta.
# d_state1: The backprop-ed gradient for the new state computed by
# timestep t.
# d_inputs: d_inputs[t, :] is populated by the backward time step t.
# d_acc_state: The backprop-ed gradient for acc_state.
# d_captured: All timesteps' gradients for the implicitly captured tensors
# are accumulated (added) into d_captured.
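# For example, with no padding and a stop_fn that never fires, t starts at
# seqlen - 1 and counts down to 0 (limit == pad_begin == 0), mirroring the
# forward loop in reverse.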
def BackwardLoopCond(loop_state):
"""Backward loop condition function."""
return loop_state.t >= loop_state.limit
def BackwardLoopBody(loop_state):
"""Backward loop body function."""
t = loop_state.t
# The input recurrent state for time step t is previous time step's
# output, or the original state0 when on time step 0.
state_from_acc = _Index(loop_state.acc_state,
tf.maximum(tf.constant(0, t.dtype), t - 1))
state0 = tf.cond(
tf.equal(t, tf.constant(0, t.dtype)),
true_fn=lambda: loop_state.state0,
false_fn=lambda: state_from_acc)
# The external inputs for time step t.
inputs_t = _Index(loop_state.inputs, t)
# The extras for time step t.
extras_t = _Index(loop_state.acc_extras, t)
d_state1 = _Add(_Index(loop_state.d_acc_state, t), loop_state.d_state1)
(d_theta_t, loop_state.d_state1, d_inputs_t,
d_captured_t) = Bak(loop_state.theta, state0, inputs_t, extras_t,
d_state1)
if self._unused_acc_state:
# XLA IF op requires the same shape for if and else branches.
loop_state.d_state1 = loop_state.d_state1.Transform(tf.reduce_sum)
loop_state.d_theta = _Add(loop_state.d_theta, d_theta_t)
loop_state.d_inputs = _Update(loop_state.d_inputs, d_inputs_t, t)
loop_state.d_captured = _Add(loop_state.d_captured, d_captured_t)
loop_state.t = tf.subtract(t, 1)
# Make sure this function didn't capture anything different than the
# cell_fn when reflected on at the beginning. Must come after the call
# to Bak() which adds to the captured list.
_AssertSameTensors(py_utils.GetExtraInputs(),
self._implicit_captures.Flatten())
return loop_state
# Backward calls BackwardLoopBody n times. Each time computes the backprop
# for one time step of the recurrent net.
def Backward(xs, ys, dys):
"""Backward pass for the recurrent net.
Args:
xs: inputs to the forward operation.
ys: outputs of the forward operation.
dys: gradients to the outputs of the forward operation.
Returns:
Gradients to the inputs of the forward operation.
"""
# Accumulators for gradients.
d_theta = _EmptyLike(xs.theta)
d_inputs = _EmptyLike(xs.inputs)
d_captured = _EmptyLike(self._implicit_captures)
# The sequence length.
pad_begin, _ = _SeqPaddingLength(xs.inputs)
limit = pad_begin
if compiled:
limit = tf.cast(limit, tf.int32)
else:
limit = tf.cast(limit, tf.int64)
state0 = xs.state0
d_state1 = dys.final_state
if self._unused_acc_state:
# XLA While op requires the same shape for the init and carry on
# values.
state0 = state0.Transform(tf.reduce_sum)
d_state1 = d_state1.Transform(tf.reduce_sum)
with py_utils.RemoveAssertContext(remove=noinline):
run = py_utils.WhileLoop(
cond=BackwardLoopCond,
body=BackwardLoopBody,
loop_state=py_utils.NestedMap(
t=ys.limit - 1,
limit=limit,
theta=xs.theta,
state0=state0,
inputs=xs.inputs,
acc_state=ys.acc_state,
acc_extras=ys.acc_extras,
d_theta=d_theta,
d_state1=d_state1,
d_inputs=d_inputs,
d_acc_state=dys.acc_state,
d_captured=d_captured))
d_state0 = run.d_state1
if self._unused_acc_state:
# Match the shape of gradient of the init_state.
d_state0 = self._state.Transform(tf.zeros_like)
d_inputs = run.d_inputs
if self._backward_cleanup is not None:
with tf.control_dependencies(d_inputs.Flatten()):
control_before = tf.no_op()
with tf.control_dependencies([control_before]):
with tf.control_dependencies(self._backward_cleanup()):
d_inputs = d_inputs.Transform(tf.identity)
# The `extras` input in the Forward function is actually an output of the
# function. It was supplied as an input only to create acc_extras with the
# proper shape, so its gradients should be zero.
return py_utils.NestedMap(
d_theta=run.d_theta,
d_state0=d_state0,
d_inputs=d_inputs,
d_extras=_EmptyLike(self._extras)), run.d_captured
# Forward arguments.
self._fwd_args = py_utils.NestedMap(
theta=self._theta,
state0=self._state,
inputs=self._inputs,
extras=self._extras)
# pylint: disable=protected-access
device_funcs = tf.get_default_graph()._device_functions_outer_to_inner
self._caller_device = device_funcs[-1] if device_funcs else None
# pylint: enable=protected-access
self._forward = Forward
self._backward = Backward