def recurrent_func()

in lingvo/jax/layers/recurrent.py


def recurrent_func(theta: NestedMap, states_0: NestedMap, inputs: NestedMap,
                   cell_fn: Callable[[NestedMap, NestedMap, NestedMap],
                                     NestedMap]):
  """Computes a recurrent neural net.

  Args:
    theta: weights. A `.NestedMap`.
    states_0: initial state. A `.NestedMap`.
    inputs: inputs. A `.NestedMap`.
    cell_fn: A Python function which computes::
        states_1 = cell_fn(theta, states_0, inputs[t, :])

  Returns:
    The final state and the accumulated states, one entry per time step.
  """
  input_seq_len = inputs.Flatten()[0].shape[0]

  def assert_not_none(x):
    assert x is not None

  tf.nest.map_structure(assert_not_none, states_0)
  tf.nest.map_structure(assert_not_none, inputs)
  tf.nest.map_structure(assert_not_none, theta)

  def new_cum_state(x):
    x1 = jnp.expand_dims(x, 0)
    # +1 so that we can store initial_states at position 0.
    return jnp.tile(x1, [input_seq_len + 1] + [1] * x.ndim)

  cumulative_states = states_0.Transform(new_cum_state)

  prng_key = base_layer.next_prng_key()
  global_step = base_layer.cur_global_step()

  start_time = jnp.array(0, dtype=jnp.uint32)
  fwd_initial_loop_vars = NestedMap(
      cur_time=start_time,
      theta=theta,
      states_0=states_0,
      cumulative_states=cumulative_states,
      inputs=inputs)

  def same_type_shape(x, y):
    assert x.dtype == y.dtype, (x.dtype, y.dtype)
    assert x.shape == y.shape, (x.shape, y.shape)

  def wrapped_cell_fn(fn_in):
    # fn_in is a NestedMap containing the following elements:
    #    - t
    #    - theta
    #    - states_0
    #    - inputs_t
    # Derive a per-step prng key chain by folding the time step into the base key.
    t = fn_in.t
    theta = fn_in.theta
    states_0 = fn_in.states_0
    inputs_t = fn_in.inputs_t
    with base_layer.JaxContext.new_context(
        prng_key=jax.random.fold_in(prng_key, t), global_step=global_step):
      # NO side-effect ops are allowed as the enclosing JaxContext is not bound
      # to any layer.
      states_1 = cell_fn(theta, states_0, inputs_t)

      tf.nest.assert_same_structure(states_0, states_1)
      tf.nest.map_structure(same_type_shape, states_0, states_1)
    return states_1

  def wrapped_cell_fn_grad(fn_in, d_fn_out):
    # This is roughly the following:
    #
    # fn_out = wrapped_cell_fn(fn_in)
    # d_fn_in = tf.gradient(fn_out, fn_in, d_fn_out)
    # return d_fn_in
    #
    assert isinstance(fn_in, NestedMap)
    fn_out, vjp_fn = jax.vjp(wrapped_cell_fn, fn_in)
    del fn_out
    d_fn_in = vjp_fn(d_fn_out)
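    # vjp_fn returns one cotangent per positional argument of wrapped_cell_fn;
    # there is a single argument here, hence the 1-tuple below.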
    assert isinstance(d_fn_in, tuple)
    assert len(d_fn_in) == 1
    d_fn_in_0 = d_fn_in[0]
    # Overwrite the gradient for t, the time step; the loop counter gets no
    # gradient.
    d_fn_in_0.t = jnp.zeros_like(fn_in.t)
    tf.nest.assert_same_structure(fn_in, d_fn_in_0)
    tf.nest.map_structure(same_type_shape, fn_in, d_fn_in_0)
    return d_fn_in_0

  def fwd_comp_fn(loop_vars):
    # loop_vars is a NestedMap containing the following elements:
    #   - cur_time
    #   - theta
    #   - inputs
    #   - cumulative_states
    #   - states_0
    t = loop_vars.cur_time
    theta = loop_vars.theta
    inputs = loop_vars.inputs
    cumulative_states = loop_vars.cumulative_states
    states_0 = loop_vars.states_0
    inputs_t = inputs.Transform(lambda x: x[t])

    states_1 = wrapped_cell_fn(
        NestedMap(t=t, theta=theta, states_0=states_0, inputs_t=inputs_t))

    def set_t(x, x_t):
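      # Write the new state at slot t + 1; slot 0 keeps the initial state
      # (see new_cum_state above).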
      return x.at[t + 1].set(x_t)

    cumulative_states = tf.nest.map_structure(set_t, cumulative_states,
                                              states_1)
    loop_out = NestedMap(
        cur_time=t + 1,
        theta=theta,
        inputs=inputs,
        states_0=states_1,
        cumulative_states=cumulative_states)
    return loop_out

  def fwd_continue_fn(loop_vars):
    return loop_vars.cur_time < input_seq_len

  # This custom_vjp implementation follows examples here:
  # https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
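  # jax.custom_vjp needs a forward rule returning (output, residuals) and a
  # backward rule mapping (residuals, output cotangents) to input cotangents;
  # these are loop_fn_vjp_fwd and loop_fn_vjp_bwd below, registered via
  # fwd_loop.defvjp(...).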
  @jax.custom_vjp
  def fwd_loop(loop_vars):
    final_loop_vars = jax.lax.while_loop(fwd_continue_fn, fwd_comp_fn,
                                         loop_vars)
    return NestedMap(
        final_states=final_loop_vars.states_0,
        cumulative_states=final_loop_vars.cumulative_states)

  def loop_fn_vjp_fwd(loop_vars):
    loop_fn_out = fwd_loop(loop_vars)
    return loop_fn_out, (loop_vars, loop_fn_out.cumulative_states)

  def loop_fn_vjp_bwd(res, d_out):
    fwd_loop_vars, cumulative_states = res
    d_final_states = d_out.final_states
    d_cumulative_states = d_out.cumulative_states

    start_time = input_seq_len - 1
    d_states_1 = tf.nest.map_structure(lambda x, y: x[start_time + 1] + y,
                                       d_cumulative_states, d_final_states)
    bwd_loop_vars = NestedMap(
        cur_time=start_time,
        theta=fwd_loop_vars.theta,
        inputs=fwd_loop_vars.inputs,
        cumulative_states=cumulative_states,
        d_cumulative_states=d_cumulative_states,
        d_theta=fwd_loop_vars.theta.Transform(jnp.zeros_like),
        d_inputs=fwd_loop_vars.inputs.Transform(jnp.zeros_like),
        d_states_1=d_states_1)
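    # The backward loop walks time from input_seq_len - 1 down to 0: each step
    # re-reads states_0 from the saved cumulative_states, backprops through
    # cell_fn via wrapped_cell_fn_grad, accumulates d_theta, scatters the
    # per-step input gradient into d_inputs, and carries d_states_1 one step
    # back in time.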

    def bwd_comp_fn(loop_vars):
      t = loop_vars.cur_time
      inputs = loop_vars.inputs
      inputs_t = inputs.Transform(lambda x: x[t])
      states_0 = loop_vars.cumulative_states.Transform(lambda x: x[t])
      d_cell_in = wrapped_cell_fn_grad(
          NestedMap(
              t=t, theta=loop_vars.theta, states_0=states_0, inputs_t=inputs_t),
          loop_vars.d_states_1)
      d_theta = tf.nest.map_structure(lambda x, y: x + y, loop_vars.d_theta,
                                      d_cell_in.theta)
      d_states_0 = tf.nest.map_structure(lambda x, y: x + y[t],
                                         d_cell_in.states_0,
                                         loop_vars.d_cumulative_states)

      def set_t(x, x_t):
        return x.at[t].set(x_t)

      d_inputs = tf.nest.map_structure(set_t, loop_vars.d_inputs,
                                       d_cell_in.inputs_t)
      loop_vars_out = loop_vars.Transform(lambda x: x)
      loop_vars_out.d_inputs = d_inputs
      loop_vars_out.d_states_1 = d_states_0
      loop_vars_out.d_theta = d_theta
      loop_vars_out.cur_time = t - 1
      return loop_vars_out

    def bwd_continue_fn(loop_vars):
      return loop_vars.cur_time >= 0

    bwd_final_loop_vars = jax.lax.while_loop(bwd_continue_fn, bwd_comp_fn,
                                             bwd_loop_vars)
    d_out = fwd_loop_vars.Transform(jnp.zeros_like)

    tf.nest.map_structure(same_type_shape, d_out.states_0,
                          bwd_final_loop_vars.d_states_1)
    tf.nest.map_structure(same_type_shape, d_out.theta,
                          bwd_final_loop_vars.d_theta)
    tf.nest.map_structure(same_type_shape, d_out.inputs,
                          bwd_final_loop_vars.d_inputs)

    d_out.states_0 = bwd_final_loop_vars.d_states_1
    d_out.theta = bwd_final_loop_vars.d_theta
    d_out.inputs = bwd_final_loop_vars.d_inputs
    return (d_out,)

  fwd_loop.defvjp(loop_fn_vjp_fwd, loop_fn_vjp_bwd)

  # Finally, run the forward loop.
  fwd_final_loop_vars = fwd_loop(fwd_initial_loop_vars)
  fwd_cumulative_states = fwd_final_loop_vars.cumulative_states.Transform(
      lambda x: x[1:])
  return fwd_final_loop_vars.final_states, fwd_cumulative_states
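
For reference, below is a minimal, standalone sketch of the forward pattern recurrent_func builds on: a jax.lax.while_loop that writes each new state into a preallocated buffer whose slot 0 holds the initial state, then drops that slot on return. It uses a bare array and a scalar weight instead of NestedMap, omits theta/inputs gradients and the per-step PRNG plumbing, and every name in it (simple_recurrent, cell) is illustrative rather than part of lingvo. The custom_vjp machinery above exists because jax.lax.while_loop does not support reverse-mode differentiation, so recurrent_func supplies its own backward loop; this sketch covers the forward pass only.

import jax
import jax.numpy as jnp


def cell(theta, state, x_t):
  # Toy cell: state_{t+1} = tanh(theta * state_t + x_t).
  return jnp.tanh(theta * state + x_t)


def simple_recurrent(theta, state_0, xs):
  seq_len = xs.shape[0]
  # seq_len + 1 slots so slot 0 can hold the initial state, mirroring
  # new_cum_state above.
  buf = jnp.zeros((seq_len + 1,) + state_0.shape, state_0.dtype)
  buf = buf.at[0].set(state_0)

  def cond_fn(carry):
    return carry[0] < seq_len

  def body_fn(carry):
    t, state, buf = carry
    state = cell(theta, state, xs[t])
    return t + 1, state, buf.at[t + 1].set(state)

  t0 = jnp.array(0, jnp.int32)
  _, final_state, buf = jax.lax.while_loop(cond_fn, body_fn, (t0, state_0, buf))
  # Drop slot 0, as done for fwd_cumulative_states above.
  return final_state, buf[1:]


final_state, states = simple_recurrent(
    jnp.float32(0.5), jnp.zeros([3]), jnp.ones([4, 3]))
print(final_state.shape, states.shape)  # (3,) (4, 3)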