in tensorflow_probability/python/experimental/mcmc/preconditioned_nuts.py [0:0]
def _loop_tree_doubling(self, step_size, velocity_state_memory,
                        current_step_meta_info, iter_, initial_step_state,
                        initial_step_metastate, momentum_distribution, seed,
                        shard_axis_names=None):
  """Runs one tree-doubling iteration and merges it into the trajectory.

  A boolean direction is sampled per chain, a subtree is grown from the
  matching end of the current trajectory through `self._build_sub_tree`, the
  subtree's candidate is accepted or rejected against the current candidate,
  and the trajectory-level U-turn criterion is re-evaluated on the extended
  trajectory.

  Args:
    step_size: Sequence of step sizes, paired element-wise with the state
      parts. Each one is negated wherever the sampled direction is False.
    velocity_state_memory: Passed through unchanged to `_build_sub_tree`.
    current_step_meta_info: Structure exposing `init_energy`, whose shape is
      used as the batch shape; also passed through to `_build_sub_tree`.
    iter_: Index of the current doubling iteration. `1 << iter_` is given to
      `_build_sub_tree` as the number of steps at this depth.
    initial_step_state: Nested structure with `state` and `velocity` fields.
      Every leaf is indexed with `[0]` and `[1]` (presumably the two ends of
      the trajectory stacked on the leading axis -- TODO confirm).
    initial_step_metastate: Metastate providing `candidate_state`,
      `is_accepted`, `momentum_sum`, `energy_diff_sum`, `continue_tree`,
      `not_divergence` and `leapfrog_count`.
    momentum_distribution: Passed through unchanged to `_build_sub_tree`.
    seed: Seed that is split into four: direction, subtree, acceptance, and
      the seed returned for the next iteration.
    shard_axis_names: NOTE(review): not read anywhere in this method; the
      U-turn check uses `self.experimental_shard_axis_names` instead.

  Returns:
    A tuple `(iter_ + 1, next_seed, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    (seed_for_direction,
     seed_for_subtree,
     seed_for_acceptance,
     seed_for_next_iter) = samplers.split_seed(seed, n=4)

    chain_shape = ps.shape(current_step_meta_info.init_energy)

    # Draw an integer in {0, 1} per chain and reinterpret it as a boolean.
    direction_bits = samplers.uniform(
        shape=chain_shape,
        minval=0,
        maxval=2,
        dtype=tf.int32,
        seed=seed_for_direction)
    direction_mask = tf.cast(direction_bits, dtype=tf.bool)

    def pick_trajectory_end(index_if_true, index_if_false):
      """Selects one of the two stacked entries of every leaf, per chain."""
      return tf.nest.map_structure(
          lambda pair: bu.where_left_justified_mask(
              direction_mask, pair[index_if_true], pair[index_if_false]),
          initial_step_state)

    tree_start_states = pick_trajectory_end(1, 0)

    directions_expanded = [
        bu.left_justified_expand_dims_like(direction_mask, state_part)
        for state_part in tree_start_states.state
    ]

    # The sign of each step size follows the sampled direction.
    signed_step_sizes = [
        tf.where(expanded_direction, part_step_size, -part_step_size)
        for expanded_direction, part_step_size in zip(
            directions_expanded, step_size)
    ]
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=signed_step_sizes,
        num_steps=self.unrolled_leapfrog_steps)

    # num_steps_at_this_depth = 2**iter_ = 1 << iter_
    num_steps_at_this_depth = tf.bitwise.left_shift(1, iter_)
    (candidate_tree_state,
     tree_final_states,
     final_not_divergence,
     continue_tree_final,
     energy_diff_tree_sum,
     momentum_subtree_cumsum,
     leapfrogs_taken) = self._build_sub_tree(
         directions_expanded,
         integrator,
         current_step_meta_info,
         num_steps_at_this_depth,
         tree_start_states,
         initial_step_metastate.continue_tree,
         initial_step_metastate.not_divergence,
         velocity_state_memory,
         momentum_distribution,
         seed=seed_for_subtree)

    previous_candidate = initial_step_metastate.candidate_state
    energy_diff_sum = (
        energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)

    # Where the subtree did not finish with `continue_tree_final` set, its
    # weight is replaced by a neutral value (-inf for log weights, zero for
    # counts).
    raw_subtree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      subtree_weight = tf.where(
          continue_tree_final,
          raw_subtree_weight,
          tf.constant(-np.inf, dtype=raw_subtree_weight.dtype))
      weight_sum = log_add_exp(subtree_weight, previous_candidate.weight)
      log_accept_thresh = subtree_weight - previous_candidate.weight
    else:
      subtree_weight = tf.where(
          continue_tree_final,
          raw_subtree_weight,
          tf.zeros([], dtype=TREE_COUNT_DTYPE))
      weight_sum = subtree_weight + previous_candidate.weight
      weight_ratio = (
          tf.cast(subtree_weight, tf.float32) /
          tf.cast(previous_candidate.weight, tf.float32))
      log_accept_thresh = tf.math.log(weight_ratio)

    # A NaN threshold is replaced by zero before the comparison.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)

    uniform_draw = samplers.uniform(
        shape=chain_shape,
        dtype=log_accept_thresh.dtype,
        seed=seed_for_acceptance)
    log_u = tf.math.log1p(-uniform_draw)
    choose_new_state = (log_u <= log_accept_thresh) & continue_tree_final

    def keep_or_replace(from_subtree, from_previous):
      """Takes the subtree value where the new candidate was chosen."""
      return bu.where_left_justified_mask(
          choose_new_state, from_subtree, from_previous)

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            keep_or_replace(new_part, old_part)
            for new_part, old_part in zip(candidate_tree_state.state,
                                          previous_candidate.state)
        ],
        target=keep_or_replace(candidate_tree_state.target,
                               previous_candidate.target),
        target_grad_parts=[
            keep_or_replace(new_grad, old_grad)
            for new_grad, old_grad in zip(
                candidate_tree_state.target_grad_parts,
                previous_candidate.target_grad_parts)
        ],
        energy=keep_or_replace(candidate_tree_state.energy,
                               previous_candidate.energy),
        weight=weight_sum)

    # Copy the static shapes of the previous candidate onto the new one.
    shape_sync_groups = (
        (new_candidate_state.state, previous_candidate.state),
        (new_candidate_state.target_grad_parts,
         previous_candidate.target_grad_parts),
    )
    for new_parts, old_parts in shape_sync_groups:
      for new_part, old_part in zip(new_parts, old_parts):
        tensorshape_util.set_shape(new_part, old_part.shape)

    # Rebuild the two stacked trajectory ends from the subtree's final states
    # and the end that was not extended, then run the trajectory-level U-turn
    # check.
    tree_otherend_states = pick_trajectory_end(0, 1)
    restacked_leaves = []
    for new_end, other_end in zip(tf.nest.flatten(tree_final_states),
                                  tf.nest.flatten(tree_otherend_states)):
      restacked_leaves.append(
          tf.stack([
              bu.where_left_justified_mask(direction_mask, other_end, new_end),
              bu.where_left_justified_mask(direction_mask, new_end, other_end),
          ], axis=0))
    new_step_state = tf.nest.pack_sequence_as(
        initial_step_state, restacked_leaves)

    # Accumulate the momentum sums, keeping the static shapes of the inputs.
    previous_momentum_sum = initial_step_metastate.momentum_sum
    momentum_tree_cumsum = [
        running + subtree_part
        for running, subtree_part in zip(previous_momentum_sum,
                                         momentum_subtree_cumsum)
    ]
    for summed_part, running in zip(momentum_tree_cumsum,
                                    previous_momentum_sum):
      tensorshape_util.set_shape(summed_part, running.shape)

    for new_leaf, old_leaf in zip(tf.nest.flatten(new_step_state),
                                  tf.nest.flatten(initial_step_state)):
      tensorshape_util.set_shape(new_leaf, old_leaf.shape)

    state_diff = (
        momentum_tree_cumsum if GENERALIZED_UTURN
        else [ends[1] - ends[0] for ends in new_step_state.state])
    velocity_at_first_end = [v[0] for v in new_step_state.velocity]
    velocity_at_second_end = [v[1] for v in new_step_state.velocity]
    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        velocity_at_first_end,
        velocity_at_second_end,
        log_prob_rank=ps.rank_from_shape(chain_shape),
        shard_axis_names=self.experimental_shard_axis_names)

    total_leapfrogs = initial_step_metastate.leapfrog_count + leapfrogs_taken
    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=energy_diff_sum,
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=total_leapfrogs)

    return iter_ + 1, seed_for_next_iter, new_step_state, new_step_metastate