def _make_evolve_trajectory()

in tensorflow_probability/python/experimental/mcmc/nuts_autobatching.py [0:0]


def _make_evolve_trajectory(value_and_gradients_fn, max_depth,
                            unrolled_leapfrog_steps, seed_stream):
  """Constructs an auto-batched NUTS trajectory evolver.

  This indirection with an explicit maker function is necessary because the
  auto-batching system used here doesn't understand non-Tensor variables.
  Consequently, `value_and_gradients_fn` and `seed_stream` have to be passed
  through the lexical context.

  The returned trajectory evolver will invoke `value_and_gradients_fn` as many
  times as requested by the longest trajectory.

  Args:
    value_and_gradients_fn: Python callable which takes arguments like
      `*current_state` and returns a batch of its (possibly unnormalized)
      log-densities under the target distribution, and the gradients thereof.
    max_depth: Maximum depth of the recursion tree, in *edges*.
    unrolled_leapfrog_steps: Number of leapfrogs to unroll per tree extension
      step.
    seed_stream: Mutable random number generator.

  Returns:
    many_steps: Auto-batched function that runs the trajectory evolution the
      requested number of times sequentially.
    ctx: The auto-batching `Context` in which `many_steps` is defined.
  """
  ctx = ab.Context()

  def many_steps_type(args):
    _, state_type, prob_type, grad_type, _, leapfrogs_type = args
    return (state_type, prob_type, grad_type), leapfrogs_type

  @ctx.batch(type_inference=many_steps_type)
  def many_steps(
      num_steps,
      current_state,
      current_target_log_prob,
      current_grads_log_prob,
      step_size,
      leapfrogs):
    """Runs `evolve_trajectory` the requested number of times sequentially."""
    current_momentum, log_slice_sample = _start_trajectory_batched(
        current_state, current_target_log_prob, seed_stream)

    current = Point(
        current_state, current_target_log_prob,
        current_grads_log_prob, current_momentum)

    if truthy(num_steps > 0):
      next_, new_leapfrogs = evolve_trajectory(
          current,
          current,
          current,
          step_size,
          log_slice_sample,
          tf.constant([0], dtype=tf.int64),  # depth
          tf.constant([1], dtype=tf.int64),  # num_states
          tf.constant([0], dtype=tf.int64),  # leapfrogs_taken
          True)  # continue_trajectory
      return many_steps(
          num_steps - 1,
          next_.state,
          next_.target_log_prob,
          next_.grads_target_log_prob,
          step_size,
          leapfrogs + new_leapfrogs)
    else:
      return ((current.state, current.target_log_prob,
               current.grads_target_log_prob), leapfrogs)

  def evolve_trajectory_type(args):
    point_type, _, _, _, _, _, _, leapfrogs_type, _ = args
    return point_type, leapfrogs_type

  @ctx.batch(type_inference=evolve_trajectory_type)
  def evolve_trajectory(
      reverse,
      forward,
      next_,
      step_size,
      log_slice_sample,
      depth,
      num_states,
      leapfrogs,
      continue_trajectory):
    """Evolves one NUTS trajectory in progress until a U-turn is encountered.

    This function is coded for one NUTS chain, and automatically batched to
    support several.  The argument descriptions below are written in
    single-chain language.

    This function only exists because the auto-batching system does not (yet)
    support syntactic while loops.  It implements a while loop by calling
    itself at the end.

    Args:
      reverse: `Point` tuple of `Tensor`s representing the "reverse" states of
        the NUTS trajectory.
      forward: `Point` tuple of `Tensor`s representing the "forward" states of
        the NUTS trajectory. Has same shape as `reverse`.
      next_: `Point` tuple of `Tensor`s representing the next states of the
        NUTS trajectory. Has same shape as `reverse`.
      step_size: List of `Tensor`s representing the step sizes for the
        leapfrog integrator. Must have the same shape as the position state
        `reverse.state`.
      log_slice_sample: The log of an auxiliary slice variable. It is used to
        pay for the posterior value at traversed states to avoid a Metropolis
        correction at the end.
      depth: non-negative integer Tensor that indicates how deep of a tree to
        build at the next trajectory doubling.
      num_states: Number of acceptable candidate states in the tree built so
        far. A state is acceptable if it is "in the slice", that is,
        if its log-joint probability with its momentum is greater than
        `log_slice_sample`.
      leapfrogs: Number of leapfrog steps computed so far.
      continue_trajectory: bool determining whether to continue the simulation
        trajectory. The trajectory is continued if no U-turns are encountered
        within the built subtree, and if the log-probability accumulation due
        to integration error does not exceed `max_simulation_error`.

    Returns:
      next_: `Point` tuple of `Tensor`s representing the state this NUTS
        trajectory transitions to.  Has same shape as `reverse`.
      leapfrogs: Number of leapfrog steps computed in the trajectory, as a
        diagnostic.
    """
    if truthy(continue_trajectory):
      # Grow the No-U-Turn Sampler trajectory by choosing a random direction
      # and simulating Hamiltonian dynamics in that direction. This extends
      # either the forward or reverse state.
      direction = _choose_direction_batched(forward, seed_stream)
      into_build_tree = _tf_where(direction < 0, reverse, forward)
      [
          reverse_out,
          forward_out,
          next_in_subtree,
          num_states_in_subtree,
          more_leapfrogs,
          continue_trajectory,
      ] = _build_tree(
          into_build_tree, direction, depth, step_size, log_slice_sample)
      # TODO(b/122732601): Revert back to `if` when the compiler makes the xform
      reverse_in = reverse
      reverse = _tf_where(direction < 0, reverse_out, reverse_in)
      forward_in = forward
      forward = _tf_where(direction < 0, forward_in, forward_out)

      # TODO(b/122732601): Revert back to `if` when the compiler makes the xform
      # If the built tree did not terminate, accept the tree's next state
      # with a certain probability.
      accept_state_in_subtree = _binomial_subtree_acceptance_batched(
          num_states_in_subtree, num_states, seed_stream)
      next_in = next_
      next_ = _tf_where(continue_trajectory & accept_state_in_subtree,
                        next_in_subtree, next_in)

      # Continue the NUTS trajectory if the tree-building did not terminate,
      # and if the reverse-most and forward-most states do not exhibit a
      # U-turn.
      continue_trajectory_in = continue_trajectory
      continue_trajectory = _continue_test_batched(
          continue_trajectory_in & (depth < max_depth), forward, reverse)
      return evolve_trajectory(
          reverse,
          forward,
          next_,
          step_size,
          log_slice_sample,
          depth + 1,
          num_states + num_states_in_subtree,
          leapfrogs + more_leapfrogs,
          continue_trajectory)
    else:
      return next_, leapfrogs

  def _build_tree_type(args):
    point_type, _, _, _, _ = args
    return (point_type, point_type, point_type,
            ab.TensorType(np.int64, ()), ab.TensorType(np.int64, ()),
            ab.TensorType(np.bool_, ()))

  @ctx.batch(type_inference=_build_tree_type)
  def _build_tree(current, direction, depth, step_size, log_slice_sample):
    """Builds a tree at a given tree depth and at a given state.

    The `current` state is immediately adjacent to, but outside of,
    the subtrajectory spanned by the returned `forward` and `reverse` states.

    This function is coded for one NUTS chain, and automatically batched to
    support several.  The argument descriptions below are written in
    single-chain language.

    Args:
      current: `Point` tuple of `Tensor`s representing the current states of
        the NUTS trajectory.
      direction: Integer Tensor that is either -1 or 1. It determines whether
        to perform leapfrog integration backward (reverse) or forward in time
        respectively.
      depth: non-negative integer Tensor that indicates how deep of a tree to
        build.  Each call to `_build_tree` takes `2**depth` leapfrog steps,
        unless stopped early by detecting a U-turn.
      step_size: List of `Tensor`s representing the step sizes for the
        leapfrog integrator. Must have the same shape as `current.state`.
      log_slice_sample: The log of an auxiliary slice variable. It is used to
        pay for the posterior value at traversed states to avoid a Metropolis
        correction at the end.

    Returns:
      reverse: `Point` tuple of `Tensor`s representing the state at the
        extreme "backward in time" point of the simulated subtrajectory. Has
        same shape as `current`.
      forward: `Point` tuple of `Tensor`s representing the state at the
        extreme "forward in time" point of the simulated subtrajectory. Has
        same shape as `current`.
      next_: `Point` tuple of `Tensor`s representing the candidate point to
        transition to, sampled from this subtree. Has same shape as
        `current`.
      num_states: Number of acceptable candidate states in the subtree. A
        state is acceptable if it is "in the slice", that is, if its log-joint
        probability with its momentum is greater than `log_slice_sample`.
      leapfrogs: Number of leapfrog steps computed in this subtree, as a
        diagnostic.
      continue_trajectory: bool determining whether to continue the simulation
        trajectory. The trajectory is continued if no U-turns are encountered
        within the built subtree, and if the log-probability accumulation due
        to integration error does not exceed `max_simulation_error`.
    """
    if truthy(depth > 0):  # Recursive case
      # Build a tree at the current state.
      (reverse, forward, next_,
       num_states, leapfrogs, continue_trajectory) = _build_tree(
           current, direction, depth - 1, step_size, log_slice_sample)
      more_leapfrogs = 0
      if truthy(continue_trajectory):
        # If the just-built subtree did not terminate, build a second subtree
        # at the forward or reverse state, as appropriate.
        # TODO(b/122732601): Revert back to `if` when compiler makes the xform
        in_ = _tf_where(direction < 0, reverse, forward)
        (reverse_out, forward_out, far,
         far_num_states, more_leapfrogs, far_continue) = _build_tree(
             in_, direction, depth - 1, step_size, log_slice_sample)
        reverse_in = reverse
        reverse = _tf_where(direction < 0, reverse_out, reverse_in)
        forward_in = forward
        forward = _tf_where(direction < 0, forward_in, forward_out)

        # Propose either `next_` (which came from the first subtree and
        # so is nearby) or the new forward/reverse state (which came from the
        # second subtree and so is far away).
        num_states_old = num_states
        num_states = num_states_old + far_num_states
        accept_far_state = _binomial_subtree_acceptance_batched(
            far_num_states, num_states, seed_stream)
        # TODO(b/122732601): Revert back to `if` when compiler makes the xform
        next_in = next_
        next_ = _tf_where(accept_far_state, far, next_in)

        # Continue the NUTS trajectory if the far subtree did not terminate
        # either, and if the reverse-most and forward-most states do not
        # exhibit a U-turn.
        continue_trajectory = _continue_test_batched(
            far_continue, forward, reverse)

      return (reverse, forward, next_,
              num_states, leapfrogs + more_leapfrogs, continue_trajectory)
    else:  # Base case
      # Take a leapfrog step. Terminate the tree-building if the simulation
      # error from the leapfrog integrator is too large. States discovered by
      # continuing the simulation are likely to have very low probability.
      next_ = _leapfrog(
          value_and_gradients_fn=value_and_gradients_fn,
          current=current,
          step_size=step_size,
          direction=direction,
          unrolled_leapfrog_steps=unrolled_leapfrog_steps)
      next_log_joint = _log_joint(next_)
      num_states = _compute_num_states_batched(
          next_log_joint, log_slice_sample)
      # This 1000 is the max_simulation_error.  Inlined instead of named so
      # TensorFlow can infer its dtype from context, b/c the type inference in
      # the auto-batching system gets confused.  TODO(axch): Re-extract.
      continue_trajectory = (next_log_joint > log_slice_sample - 1000.)
      return (next_, next_, next_, num_states, unrolled_leapfrog_steps,
              continue_trajectory)

  return many_steps, ctx
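

# A minimal usage sketch (assumptions flagged inline; the real caller lives in
# the NUTS kernel elsewhere in this module, and the execution details of the
# auto-batching `Context` are elided rather than documented here):
#
#   seed_stream = SeedStream(seed, salt='nuts')         # assumed RNG wrapper
#   many_steps, ctx = _make_evolve_trajectory(
#       value_and_gradients_fn=target_vjp_fn,           # hypothetical callable
#       max_depth=10,
#       unrolled_leapfrog_steps=1,
#       seed_stream=seed_stream)
#   # `many_steps` is then run under `ctx`, taking (num_steps, state,
#   # target_log_prob, grads_target_log_prob, step_size, leapfrogs=0) and
#   # returning the evolved (state, target_log_prob, grads_target_log_prob)
#   # plus a leapfrog-count diagnostic.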