in tensorflow_probability/python/experimental/mcmc/preconditioned_nuts.py [0:0]
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         current_step_meta_info,
                         iter_,
                         energy_diff_sum_previous,
                         momentum_cumsum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         velocity_state_memory,
                         momentum_distribution,
                         seed):
  """Base case in tree doubling: take one leapfrog step and update the subtree.

  Advances the trajectory by a single `integrator` step from `prev_tree_state`.
  It then:

  * checks the new point for U-turns against the entries stored in
    `velocity_state_memory`;
  * checks the new point for divergence;
  * writes the new point's velocity and state into that memory;
  * randomly decides whether the new point replaces the subtree's current
    candidate sample.

  The returned tuple mirrors the loop-carried arguments (`iter_`, `seed`,
  `energy_diff_sum_previous`, ..., `velocity_state_memory`). This suggests it
  is used as the body of a while loop; TODO confirm against the caller.

  Args:
    directions: Forwarded to `has_not_u_turn_at_all_index` for the U-turn
      check. Presumably encodes the integration direction(s); confirm.
    integrator: Callable invoked as
      `integrator(momentum, state, target, target_grad_parts,
      kinetic_energy_fn=...)`. It must return
      `[momentum_parts, state_parts, target, target_grad_parts]` for the new
      point.
    current_step_meta_info: Supplies `read_instruction` and
      `write_instruction` (both looked up at `iter_` via `.gather`) and
      `init_energy`. When `MULTINOMIAL_SAMPLE` is false it also supplies
      `log_slice_sample`.
    iter_: Index of the current leapfrog step within the subtree. Selects
      the memory slots to read from and write to.
    energy_diff_sum_previous: Running sum of `min(1, exp(energy_diff))` over
      the steps counted so far.
    momentum_cumsum_previous: List of running sums of the momentum parts.
    leapfrogs_taken: Number of leapfrog steps taken while the tree was still
      being extended.
    prev_tree_state: `TreeDoublingState` to integrate from.
    candidate_tree_state: `TreeDoublingStateCandidate` holding the current
      candidate sample and its accumulated `weight`.
    continue_tree_previous: Boolean tensor. Whether the tree was still being
      extended before this step.
    not_divergent_previous: Boolean tensor. Whether no divergence has been
      recorded so far.
    velocity_state_memory: `VelocityStateSwap` holding stored velocities
      (`velocity_swap`) and states (`state_swap`) for the U-turn check.
    momentum_distribution: Used to build the kinetic energy function and to
      compute the energy of the new point.
    seed: Seed passed to `samplers.split_seed`. One half drives this step's
      acceptance draw; the other half is returned for the next step.

  Returns:
    A 10-tuple of the updated loop-carried values:
    `(iter_ + 1, next_seed, energy_diff_sum, momentum_cumsum,
    leapfrogs_taken, next_tree_state, next_candidate_tree_state,
    continue_tree_next, not_divergent, velocity_state_memory)`.
    `not_divergent` is `not_divergent_previous` combined with this step's
    divergence check.
  """
  # `acceptance_seed` is consumed by the uniform draw below; `next_seed` is
  # handed back to the caller for the following iteration.
  acceptance_seed, next_seed = samplers.split_seed(seed)
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence
    kinetic_energy_fn = get_kinetic_energy_fn(momentum_distribution)
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts,
                   kinetic_energy_fn=kinetic_energy_fn)
    # The velocity of the new point is taken from the gradient output of
    # `maybe_call_fn_and_grads`, i.e. the gradient of the kinetic energy with
    # respect to the momentum parts. The kinetic energy value is discarded.
    _, next_velocity_parts = mcmc_util.maybe_call_fn_and_grads(
        kinetic_energy_fn, next_momentum_parts)
    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        velocity=next_velocity_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    # Running momentum sum along the trajectory. It is used by the
    # generalized U-turn criterion (see `GENERALIZED_UTURN` below).
    momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                                 next_momentum_parts)]
    # Count this leapfrog step only if the tree had not terminated before it.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)
    write_instruction = current_step_meta_info.write_instruction
    read_instruction = current_step_meta_info.read_instruction
    init_energy = current_step_meta_info.init_energy
    # Choose which quantity is stored in, and checked against, the U-turn
    # memory. With the generalized criterion this is a momentum sum, and the
    # *previous* (pre-step) sum is written while the *updated* sum is checked.
    # Otherwise the raw position parts are used for both.
    if GENERALIZED_UTURN:
      state_to_write = momentum_cumsum_previous
      state_to_check = momentum_cumsum
    else:
      state_to_write = next_state_parts
      state_to_check = next_state_parts
    batch_shape = ps.shape(next_target)
    has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)
    # Memory slot index(es) to compare the new point against at this step.
    read_index = read_instruction.gather([iter_])[0]
    no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
        read_index,
        directions,
        velocity_state_memory,
        next_velocity_parts,
        state_to_check,
        has_not_u_turn_init,
        log_prob_rank=ps.rank(next_target),
        shard_axis_names=self.experimental_shard_axis_names)
    # Get index to write state into memory swap
    write_index = write_instruction.gather([iter_])
    # Store the new velocity and the chosen state quantity at `write_index`
    # so that later steps in the subtree can U-turn-check against them.
    velocity_state_memory = VelocityStateSwap(
        velocity_swap=[
            _safe_tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(velocity_state_memory.velocity_swap,
                                next_velocity_parts)
        ],
        state_swap=[
            _safe_tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(velocity_state_memory.state_swap,
                                state_to_write)
        ])
    # Judging from its use below (a NaN is mapped to -inf, and
    # exp(min(energy_diff, 0)) serves as an acceptance probability), `energy`
    # behaves like a log joint density, i.e. a negative Hamiltonian.
    # NOTE(review): confirm the sign convention of `compute_hamiltonian`.
    energy = compute_hamiltonian(next_target, next_momentum_parts,
                                 momentum_distribution)
    # Replace NaN with -inf. The resulting energy_diff of -inf is flagged as
    # divergent below and gets a -inf acceptance threshold, so the point is
    # never chosen as the candidate.
    current_energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype),
                              energy)
    energy_diff = current_energy - init_energy
    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      # `weight` accumulates log(sum(exp(energy_diff))) over the subtree. The
      # new point replaces the candidate with probability
      # exp(energy_diff) / exp(weight_sum) (compare against `u` below).
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      log_slice_sample = current_step_meta_info.log_slice_sample
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      # `weight` counts the valid points. A valid new point is accepted with
      # probability 1 / count; an invalid one never is.
      # NOTE(review): the threshold is computed in float32 regardless of the
      # model dtype. `u` below inherits this dtype.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    # u = log(1 - U) with U uniform (default range assumed [0, 1)), so
    # P(u <= t) = exp(t) for t <= 0. Thus `u <= log_accept_thresh` accepts
    # with probability exp(log_accept_thresh).
    u = tf.math.log1p(-samplers.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=acceptance_seed))
    is_sample_accepted = u <= log_accept_thresh
    # Per batch member, keep the new point if accepted, else keep the old
    # candidate. `bu.where_left_justified_mask` presumably broadcasts the
    # batch-shaped mask against the leading dims of each part; confirm.
    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            bu.where_left_justified_mask(is_sample_accepted, s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=bu.where_left_justified_mask(
            is_sample_accepted, next_target, candidate_tree_state.target),
        target_grad_parts=[
            bu.where_left_justified_mask(is_sample_accepted, grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        energy=bu.where_left_justified_mask(
            is_sample_accepted,
            current_energy,
            candidate_tree_state.energy),
        weight=weight_sum)
    # The tree keeps growing only while no divergence and no U-turn has been
    # seen. `continue_tree` (divergence only) gates the energy statistic
    # below. `continue_tree_next` additionally requires no U-turn and is what
    # is returned.
    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree
    # Once the tree has already stopped, later steps must not change the
    # divergence flag, so they contribute True (a no-op under `&`).
    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        ps.ones(batch_shape, dtype=tf.bool))
    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)
    return (
        iter_ + 1,
        next_seed,
        energy_diff_sum,
        momentum_cumsum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        velocity_state_memory,
    )