lingvo/jax/layers/recurrent.py
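# NOTE: this excerpt omits the module-level imports. Reconstructed from the
# names used below (an assumption, not copied verbatim from the file), they
# would be roughly:
#
#   from typing import Callable
#   import jax
#   from jax import numpy as jnp
#   import tensorflow.compat.v2 as tf
#   from lingvo.jax import base_layer
#   from lingvo.jax import py_utils
#
#   NestedMap = py_utils.NestedMap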
def recurrent_func(theta: NestedMap, states_0: NestedMap, inputs: NestedMap,
cell_fn: Callable[[NestedMap, NestedMap, NestedMap],
NestedMap]):
"""Computes a recurrent neural net.
Args:
theta: weights. A `.NestedMap`.
states_0: initial state. A `.NestedMap`.
inputs: inputs. A `.NestedMap`.
cell_fn: A python function which computes::
states_1 = cell_fn(theta, states_0, inputs[t, :])
Returns:
`accumulate_state` and the final state.
"""
input_seq_len = inputs.Flatten()[0].shape[0]
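  # Reject structures with missing (None) leaves before building the loop.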
def assert_not_none(x):
assert x is not None
tf.nest.map_structure(assert_not_none, states_0)
tf.nest.map_structure(assert_not_none, inputs)
tf.nest.map_structure(assert_not_none, theta)
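  # Pre-allocate a stacked buffer for the state at every step: index 0 holds
  # the initial state, index t + 1 the state produced by step t.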
def new_cum_state(x):
x1 = jnp.expand_dims(x, 0)
# +1 so that we can store initial_states at position 0.
return jnp.tile(x1, [input_seq_len + 1] + [1] * x.ndim)
cumulative_states = states_0.Transform(new_cum_state)
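  # Capture the PRNG key and global step from the enclosing JaxContext; the
  # key is folded with the time step inside the loop so each step gets its
  # own randomness.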
prng_key = base_layer.next_prng_key()
global_step = base_layer.cur_global_step()
start_time = jnp.array(0, dtype=jnp.uint32)
fwd_initial_loop_vars = NestedMap(
cur_time=start_time,
theta=theta,
states_0=states_0,
cumulative_states=cumulative_states,
inputs=inputs)
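  # Asserts that two leaves agree in dtype and shape; used to check that
  # cell_fn preserves the state structure and that cotangents match their
  # primals.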
def same_type_shape(x, y):
assert x.dtype == y.dtype, (x.dtype, y.dtype)
assert x.shape == y.shape, (x.shape, y.shape)
def wrapped_cell_fn(fn_in):
    # fn_in is a NestedMap containing the following elements:
# - t
# - theta
# - states_0
# - inputs_t
    # Start a chain of PRNG keys that also takes the time step into account.
t = fn_in.t
theta = fn_in.theta
states_0 = fn_in.states_0
inputs_t = fn_in.inputs_t
with base_layer.JaxContext.new_context(
prng_key=jax.random.fold_in(prng_key, t), global_step=global_step):
# NO side-effect ops are allowed as the enclosing JaxContext is not bound
# to any layer.
states_1 = cell_fn(theta, states_0, inputs_t)
tf.nest.assert_same_structure(states_0, states_1)
tf.nest.map_structure(same_type_shape, states_0, states_1)
return states_1
def wrapped_cell_fn_grad(fn_in, d_fn_out):
# This is roughly the following:
#
# fn_out = wrapped_cell_fn(fn_in)
# d_fn_in = tf.gradient(fn_out, fn_in, d_fn_out)
# return d_fn_in
#
assert isinstance(fn_in, NestedMap)
fn_out, vjp_fn = jax.vjp(wrapped_cell_fn, fn_in)
del fn_out
d_fn_in = vjp_fn(d_fn_out)
assert isinstance(d_fn_in, tuple)
assert len(d_fn_in) == 1
d_fn_in_0 = d_fn_in[0]
    # Overwrite the gradient for t, the time step, which is not differentiated.
d_fn_in_0.t = jnp.zeros_like(fn_in.t)
tf.nest.assert_same_structure(fn_in, d_fn_in_0)
tf.nest.map_structure(same_type_shape, fn_in, d_fn_in_0)
return d_fn_in_0
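  # Body of the forward while_loop: run the cell for step t and record the new
  # state at index t + 1 of the cumulative buffer.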
def fwd_comp_fn(loop_vars):
# loop_vars is a NestedMap containing the following elements:
# - cur_time
# - theta
# - inputs
# - cumulative_states
# - states_0
t = loop_vars.cur_time
theta = loop_vars.theta
inputs = loop_vars.inputs
cumulative_states = loop_vars.cumulative_states
states_0 = loop_vars.states_0
inputs_t = inputs.Transform(lambda x: x[t])
states_1 = wrapped_cell_fn(
NestedMap(t=t, theta=theta, states_0=states_0, inputs_t=inputs_t))
def set_t(x, x_t):
return x.at[t + 1].set(x_t)
cumulative_states = tf.nest.map_structure(set_t, cumulative_states,
states_1)
loop_out = NestedMap(
cur_time=t + 1,
theta=theta,
inputs=inputs,
states_0=states_1,
cumulative_states=cumulative_states)
return loop_out
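  # Continue the forward loop while time steps remain.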
def fwd_continue_fn(loop_vars):
return loop_vars.cur_time < input_seq_len
# This custom_vjp implementation follows examples here:
# https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
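  # The backward rule reuses the stored cumulative states as each step's
  # inputs, re-running the cell once per step via jax.vjp to obtain per-step
  # gradients in reverse time order.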
@jax.custom_vjp
def fwd_loop(loop_vars):
final_loop_vars = jax.lax.while_loop(fwd_continue_fn, fwd_comp_fn,
loop_vars)
return NestedMap(
final_states=final_loop_vars.states_0,
cumulative_states=final_loop_vars.cumulative_states)
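  # custom_vjp forward rule: run the loop and save the original loop variables
  # plus the cumulative states as residuals for the backward pass.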
def loop_fn_vjp_fwd(loop_vars):
loop_fn_out = fwd_loop(loop_vars)
return loop_fn_out, (loop_vars, loop_fn_out.cumulative_states)
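  # custom_vjp backward rule: walk the time steps in reverse, accumulating
  # gradients w.r.t. theta and inputs and chaining the state cotangent from
  # step t + 1 back into step t.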
def loop_fn_vjp_bwd(res, d_out):
fwd_loop_vars, cumulative_states = res
d_final_states = d_out.final_states
d_cumulative_states = d_out.cumulative_states
start_time = input_seq_len - 1
d_states_1 = tf.nest.map_structure(lambda x, y: x[start_time + 1] + y,
d_cumulative_states, d_final_states)
bwd_loop_vars = NestedMap(
cur_time=start_time,
theta=fwd_loop_vars.theta,
inputs=fwd_loop_vars.inputs,
cumulative_states=cumulative_states,
d_cumulative_states=d_cumulative_states,
d_theta=fwd_loop_vars.theta.Transform(jnp.zeros_like),
d_inputs=fwd_loop_vars.inputs.Transform(jnp.zeros_like),
d_states_1=d_states_1)
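    # Body of the backward while_loop: one reverse-time step.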
def bwd_comp_fn(loop_vars):
t = loop_vars.cur_time
inputs = loop_vars.inputs
inputs_t = inputs.Transform(lambda x: x[t])
states_0 = loop_vars.cumulative_states.Transform(lambda x: x[t])
d_cell_in = wrapped_cell_fn_grad(
NestedMap(
t=t, theta=loop_vars.theta, states_0=states_0, inputs_t=inputs_t),
loop_vars.d_states_1)
d_theta = tf.nest.map_structure(lambda x, y: x + y, loop_vars.d_theta,
d_cell_in.theta)
d_states_0 = tf.nest.map_structure(lambda x, y: x + y[t],
d_cell_in.states_0,
loop_vars.d_cumulative_states)
def set_t(x, x_t):
return x.at[t].set(x_t)
d_inputs = tf.nest.map_structure(set_t, loop_vars.d_inputs,
d_cell_in.inputs_t)
loop_vars_out = loop_vars.Transform(lambda x: x)
loop_vars_out.d_inputs = d_inputs
loop_vars_out.d_states_1 = d_states_0
loop_vars_out.d_theta = d_theta
loop_vars_out.cur_time = t - 1
return loop_vars_out
def bwd_continue_fn(loop_vars):
return loop_vars.cur_time >= 0
bwd_final_loop_vars = jax.lax.while_loop(bwd_continue_fn, bwd_comp_fn,
bwd_loop_vars)
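    # Package the cotangents with the same structure as the primal loop_vars;
    # fields that are not differentiated (e.g. cur_time, cumulative_states)
    # stay zero.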
d_out = fwd_loop_vars.Transform(jnp.zeros_like)
tf.nest.map_structure(same_type_shape, d_out.states_0,
bwd_final_loop_vars.d_states_1)
tf.nest.map_structure(same_type_shape, d_out.theta,
bwd_final_loop_vars.d_theta)
tf.nest.map_structure(same_type_shape, d_out.inputs,
bwd_final_loop_vars.d_inputs)
d_out.states_0 = bwd_final_loop_vars.d_states_1
d_out.theta = bwd_final_loop_vars.d_theta
d_out.inputs = bwd_final_loop_vars.d_inputs
return (d_out,)
fwd_loop.defvjp(loop_fn_vjp_fwd, loop_fn_vjp_bwd)
  # Finally, run the forward loop.
fwd_final_loop_vars = fwd_loop(fwd_initial_loop_vars)
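  # Drop index 0 (the initial state) so the accumulated states have exactly
  # one entry per time step.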
fwd_cumulative_states = fwd_final_loop_vars.cumulative_states.Transform(
lambda x: x[1:])
return fwd_final_loop_vars.final_states, fwd_cumulative_states
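# --- Usage sketch (illustrative only, not part of recurrent.py) -------------
# A minimal, hypothetical example: a decay-and-add cell run over a 7-step
# sequence. The cell, the shapes, and the explicit JaxContext wrapper
# (mirroring the context opened inside wrapped_cell_fn) are assumptions made
# for illustration; inside a layer's fprop the context is already set up.
def _recurrent_func_example():
  def decay_cell(theta, states_0, inputs_t):
    # cell_fn must return states with the same structure, dtypes and shapes
    # as states_0.
    return NestedMap(x=states_0.x * theta.decay + inputs_t.x)

  theta = NestedMap(decay=jnp.asarray(0.9))
  states_0 = NestedMap(x=jnp.zeros([4]))   # batch of 4
  inputs = NestedMap(x=jnp.ones([7, 4]))   # [time, batch]

  # recurrent_func reads a PRNG key and global step from the enclosing
  # JaxContext, so open one explicitly here.
  with base_layer.JaxContext.new_context(
      prng_key=jax.random.PRNGKey(0), global_step=jnp.array(0)):
    final_state, accumulated = recurrent_func(theta, states_0, inputs,
                                              decay_cell)
  # final_state.x has shape [4]; accumulated.x has shape [7, 4].
  return final_state, accumulated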