in tensorflow_probability/python/experimental/mcmc/nuts_autobatching.py [0:0]
def _make_evolve_trajectory(value_and_gradients_fn, max_depth,
                            unrolled_leapfrog_steps, seed_stream):
  """Constructs an auto-batched NUTS trajectory evolver.

  This indirection with an explicit maker function is necessary because the
  auto-batching system this uses doesn't understand non-Tensor variables.
  Consequently, `value_and_gradients_fn`, `max_depth`,
  `unrolled_leapfrog_steps`, and `seed_stream` have to be passed through the
  lexical context: the nested functions below close over them instead of
  taking them as arguments.

  The nested functions are written for a single chain and registered with a
  fresh `ab.Context` via `@ctx.batch`, which batches them across chains.
  Because of that batching, `value_and_gradients_fn` (called from `_leapfrog`)
  will be invoked as many times as requested by the longest trajectory in the
  batch.

  Args:
    value_and_gradients_fn: Python callable which takes arguments like
      `*current_state` and returns a batch of its (possibly unnormalized)
      log-densities under the target distribution, and the gradients thereof.
    max_depth: Maximum depth of the recursion tree, in *edges*.
      `evolve_trajectory` masks its continuation flag with
      `depth < max_depth` before each further doubling.
    unrolled_leapfrog_steps: Number of leapfrogs to unroll per tree extension
      step. Passed to `_leapfrog` at every leaf of the tree, and each leaf is
      counted as this many leapfrogs in the returned diagnostics.
    seed_stream: Mutable random number generator.

  Returns:
    many_steps: Auto-batched function (registered with `ctx`) that runs
      `num_steps` NUTS transitions sequentially and returns
      `((state, target_log_prob, grads_target_log_prob), leapfrogs)`.
    ctx: The `ab.Context` in which `many_steps`, `evolve_trajectory`, and
      `_build_tree` are registered. Presumably the caller uses it to execute
      `many_steps`; confirm at the call site.
  """
  ctx = ab.Context()

  def many_steps_type(args):
    """Type inference for `many_steps`.

    Maps the argument types to the return type
    `((state, target_log_prob, grads), leapfrogs)`.
    """
    _, state_type, prob_type, grad_type, _, leapfrogs_type = args
    return (state_type, prob_type, grad_type), leapfrogs_type

  @ctx.batch(type_inference=many_steps_type)
  def many_steps(
      num_steps,
      current_state,
      current_target_log_prob,
      current_grads_log_prob,
      step_size,
      leapfrogs):
    """Runs `evolve_trajectory` the requested number of times sequentially.

    Coded for one NUTS chain and auto-batched across chains. The loop over
    transitions is expressed as self-recursion with `num_steps - 1`, because
    the auto-batching system does not support syntactic while loops.

    Args:
      num_steps: Number of NUTS transitions still to run.
      current_state: Chain state to start the next transition from.
      current_target_log_prob: Target log-density at `current_state`.
      current_grads_log_prob: Gradients of the target log-density at
        `current_state`.
      step_size: Step sizes for the leapfrog integrator.
      leapfrogs: Running count of leapfrog steps taken so far.

    Returns:
      (state, target_log_prob, grads_target_log_prob): The chain after the
        remaining `num_steps` transitions.
      leapfrogs: The input `leapfrogs` plus the leapfrog steps taken by those
        transitions.
    """
    # Draw a fresh momentum and log slice variable for this transition.
    # NOTE(review): these draws also happen on the terminating call
    # (`num_steps <= 0`), where both results go unused.
    current_momentum, log_slice_sample = _start_trajectory_batched(
        current_state, current_target_log_prob, seed_stream)
    current = Point(
        current_state, current_target_log_prob,
        current_grads_log_prob, current_momentum)
    if truthy(num_steps > 0):
      # A fresh trajectory starts with its reverse end, forward end, and
      # candidate state all equal to `current`. It has depth 0, one candidate
      # state, and no leapfrogs taken yet.
      next_, new_leapfrogs = evolve_trajectory(
          current,
          current,
          current,
          step_size,
          log_slice_sample,
          tf.constant([0], dtype=tf.int64),  # depth
          tf.constant([1], dtype=tf.int64),  # num_states
          tf.constant([0], dtype=tf.int64),  # leapfrogs_taken
          True)  # continue_trajectory
      # Loop: continue from the chosen state and accumulate the leapfrog count.
      return many_steps(
          num_steps - 1,
          next_.state,
          next_.target_log_prob,
          next_.grads_target_log_prob,
          step_size,
          leapfrogs + new_leapfrogs)
    else:
      return ((current.state, current.target_log_prob,
               current.grads_target_log_prob), leapfrogs)

  def evolve_trajectory_type(args):
    """Type inference for `evolve_trajectory`.

    Returns the type of `reverse` (used for the returned `next_`) and the type
    of `leapfrogs`.
    """
    point_type, _, _, _, _, _, _, leapfrogs_type, _ = args
    return point_type, leapfrogs_type

  @ctx.batch(type_inference=evolve_trajectory_type)
  def evolve_trajectory(
      reverse,
      forward,
      next_,
      step_size,
      log_slice_sample,
      depth,
      num_states,
      leapfrogs,
      continue_trajectory):
    """Evolves one NUTS trajectory in progress until a U-turn is encountered.

    This function is coded for one NUTS chain, and automatically batched to
    support several. The argument descriptions below are written in
    single-chain language.

    This function only exists because the auto-batching system does not (yet)
    support syntactic while loops. It implements a while loop by calling
    itself at the end. Each iteration extends the trajectory by one subtree of
    depth `depth`, grown in a randomly chosen direction.

    Whether to keep iterating is decided by `_continue_test_batched`. Its
    first argument is the subtree's own continue flag masked with
    `depth < max_depth`, where `max_depth` comes from the enclosing scope.

    Args:
      reverse: `Point` tuple of `Tensor`s representing the "reverse" states of
        the NUTS trajectory.
      forward: `Point` tuple of `Tensor`s representing the "forward" states of
        the NUTS trajectory. Has same shape as `reverse`.
      next_: `Point` tuple of `Tensor`s representing the next states of the
        NUTS trajectory. Has same shape as `reverse`.
      step_size: List of `Tensor`s representing the step sizes for the
        leapfrog integrator. Must have same shape as the chain state
        (`current_state` in `many_steps`).
      log_slice_sample: The log of an auxiliary slice variable. It is used to
        pay for the posterior value at traversed states to avoid a Metropolis
        correction at the end.
      depth: non-negative integer Tensor that indicates how deep of a tree to
        build at the next trajectory doubling.
      num_states: Number of acceptable candidate states in the initial tree
        built so far. A state is acceptable if it is "in the slice", that is,
        if its log-joint probability with its momentum is greater than
        `log_slice_sample`.
      leapfrogs: Number of leapfrog steps computed so far.
      continue_trajectory: bool determining whether to continue the simulation
        trajectory. The trajectory is continued if no U-turns are encountered
        within the built subtree, and if the log-probability accumulation due
        to integration error does not exceed `max_simulation_error`. That
        threshold is hard-coded as 1000 in `_build_tree`'s base case.

    Returns:
      next_: `Point` tuple of `Tensor`s representing the state this NUTS
        trajectory transitions to. Has same shape as `reverse`.
      leapfrogs: Number of leapfrog steps computed in the trajectory, as a
        diagnostic.
    """
    if truthy(continue_trajectory):
      # Grow the No-U-Turn Sampler trajectory by choosing a random direction
      # and simulating Hamiltonian dynamics in that direction. This extends
      # either the forward or reverse state.
      direction = _choose_direction_batched(forward, seed_stream)
      # Build from the reverse end when going backward in time, else from the
      # forward end.
      into_build_tree = _tf_where(direction < 0, reverse, forward)
      [
          reverse_out,
          forward_out,
          next_in_subtree,
          num_states_in_subtree,
          more_leapfrogs,
          continue_trajectory,
      ] = _build_tree(
          into_build_tree, direction, depth, step_size, log_slice_sample)
      # Only the end the subtree grew from moves. The `_tf_where` selection
      # (through the `*_in` temporaries) stands in for an `if` on `direction`.
      # TODO(b/122732601): Revert back to `if` when the compiler makes the xform
      reverse_in = reverse
      reverse = _tf_where(direction < 0, reverse_out, reverse_in)
      forward_in = forward
      forward = _tf_where(direction < 0, forward_in, forward_out)
      # TODO(b/122732601): Revert back to `if` when the compiler makes the xform
      # If the built tree did not terminate, accept the tree's next state
      # with a certain probability.
      accept_state_in_subtree = _binomial_subtree_acceptance_batched(
          num_states_in_subtree, num_states, seed_stream)
      next_in = next_
      next_ = _tf_where(continue_trajectory & accept_state_in_subtree,
                        next_in_subtree, next_in)
      # Continue the NUTS trajectory if the tree-building did not terminate,
      # and if the reverse-most and forward-most states do not exhibit a
      # U-turn. Masking with `depth < max_depth` caps the number of doublings.
      continue_trajectory_in = continue_trajectory
      continue_trajectory = _continue_test_batched(
          continue_trajectory_in & (depth < max_depth), forward, reverse)
      # Loop by self-recursion: the next doubling is one level deeper, and
      # the state and leapfrog counts accumulate.
      return evolve_trajectory(
          reverse,
          forward,
          next_,
          step_size,
          log_slice_sample,
          depth + 1,
          num_states + num_states_in_subtree,
          leapfrogs + more_leapfrogs,
          continue_trajectory)
    else:
      return next_, leapfrogs

  def _build_tree_type(args):
    """Type inference for `_build_tree`.

    Returns three `Point`s with the type of `current`, followed by scalar
    int64 `num_states`, scalar int64 `leapfrogs`, and a scalar bool
    continuation flag.
    """
    point_type, _, _, _, _ = args
    return (point_type, point_type, point_type,
            ab.TensorType(np.int64, ()), ab.TensorType(np.int64, ()),
            ab.TensorType(np.bool_, ()))

  @ctx.batch(type_inference=_build_tree_type)
  def _build_tree(current, direction, depth, step_size, log_slice_sample):
    """Builds a tree at a given tree depth and at a given state.

    The `current` state is immediately adjacent to, but outside of,
    the subtrajectory spanned by the returned `forward` and `reverse` states.

    This function is coded for one NUTS chain, and automatically batched to
    support several. The argument descriptions below are written in
    single-chain language.

    Args:
      current: `Point` tuple of `Tensor`s representing the current states of
        the NUTS trajectory.
      direction: Integer Tensor that is either -1 or 1. It determines whether
        to perform leapfrog integration backward (reverse) or forward in time
        respectively.
      depth: non-negative integer Tensor that indicates how deep of a tree to
        build. Each call to `_build_tree` makes `2**depth` base-case
        `_leapfrog` calls, each counted as `unrolled_leapfrog_steps`
        leapfrogs. It makes fewer if tree-building terminates early because
        of a U-turn or excessive integration error.
      step_size: List of `Tensor`s representing the step sizes for the
        leapfrog integrator. Must have same shape as the chain state
        (`current_state` in `many_steps`).
      log_slice_sample: The log of an auxiliary slice variable. It is used to
        pay for the posterior value at traversed states to avoid a Metropolis
        correction at the end.

    Returns:
      reverse: `Point` tuple of `Tensor`s representing the state at the
        extreme "backward in time" point of the simulated subtrajectory. Has
        same shape as `current`.
      forward: `Point` tuple of `Tensor`s representing the state at the
        extreme "forward in time" point of the simulated subtrajectory. Has
        same shape as `current`.
      next_: `Point` tuple of `Tensor`s representing the candidate point to
        transition to, sampled from this subtree. Has same shape as
        `current`.
      num_states: Number of acceptable candidate states in the subtree. A
        state is acceptable if it is "in the slice", that is, if its log-joint
        probability with its momentum is greater than `log_slice_sample`.
      leapfrogs: Number of leapfrog steps computed in this subtree, as a
        diagnostic.
      continue_trajectory: bool determining whether to continue the simulation
        trajectory. The trajectory is continued if no U-turns are encountered
        within the built subtree, and if the log-probability accumulation due
        to integration error does not exceed `max_simulation_error`. That
        threshold is hard-coded as 1000 in the base case below.
    """
    if truthy(depth > 0):  # Recursive case
      # Build a tree at the current state.
      (reverse, forward, next_,
       num_states, leapfrogs, continue_trajectory) = _build_tree(
           current, direction, depth - 1, step_size, log_slice_sample)
      # Leapfrog count of the second subtree. It stays 0 when that subtree is
      # never built because the first one terminated.
      more_leapfrogs = 0
      if truthy(continue_trajectory):
        # If the just-built subtree did not terminate, build a second subtree
        # at the forward or reverse state, as appropriate.
        # TODO(b/122732601): Revert back to `if` when compiler makes the xform
        in_ = _tf_where(direction < 0, reverse, forward)
        (reverse_out, forward_out, far,
         far_num_states, more_leapfrogs, far_continue) = _build_tree(
             in_, direction, depth - 1, step_size, log_slice_sample)
        # Only the end the second subtree grew from moves: the reverse end
        # when going backward in time, the forward end otherwise.
        reverse_in = reverse
        reverse = _tf_where(direction < 0, reverse_out, reverse_in)
        forward_in = forward
        forward = _tf_where(direction < 0, forward_in, forward_out)
        # Propose either `next_` (which came from the first subtree and
        # so is nearby) or the new forward/reverse state (which came from the
        # second subtree and so is far away).
        num_states_old = num_states
        num_states = num_states_old + far_num_states
        # Random (via `seed_stream`) choice of the far candidate, presumably
        # weighted by `far_num_states` relative to the combined `num_states`.
        # TODO confirm against `_binomial_subtree_acceptance_batched`.
        accept_far_state = _binomial_subtree_acceptance_batched(
            far_num_states, num_states, seed_stream)
        # TODO(b/122732601): Revert back to `if` when compiler makes the xform
        next_in = next_
        next_ = _tf_where(accept_far_state, far, next_in)
        # Continue the NUTS trajectory if the far subtree did not terminate
        # either, and if the reverse-most and forward-most states do not
        # exhibit a U-turn.
        continue_trajectory = _continue_test_batched(
            far_continue, forward, reverse)
      return (reverse, forward, next_,
              num_states, leapfrogs + more_leapfrogs, continue_trajectory)
    else:  # Base case
      # Take a leapfrog step. Terminate the tree-building if the simulation
      # error from the leapfrog integrator is too large. States discovered by
      # continuing the simulation are likely to have very low probability.
      next_ = _leapfrog(
          value_and_gradients_fn=value_and_gradients_fn,
          current=current,
          step_size=step_size,
          direction=direction,
          unrolled_leapfrog_steps=unrolled_leapfrog_steps)
      next_log_joint = _log_joint(next_)
      num_states = _compute_num_states_batched(
          next_log_joint, log_slice_sample)
      # This 1000 is the max_simulation_error. Inlined instead of named so
      # TensorFlow can infer its dtype from context, b/c the type inference in
      # the auto-batching system gets confused. TODO(axch): Re-extract.
      continue_trajectory = (next_log_joint > log_slice_sample - 1000.)
      # A single leaf is simultaneously the reverse end, the forward end, and
      # the only candidate of this subtree. It is counted as
      # `unrolled_leapfrog_steps` leapfrogs.
      return (next_, next_, next_, num_states, unrolled_leapfrog_steps,
              continue_trajectory)

  return many_steps, ctx