in pyro/infer/mcmc/nuts.py [0:0]
def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
    """Recursively build a trajectory tree of depth ``tree_depth`` by repeated doubling.

    A depth-0 tree is delegated to ``self._build_basetree``. For larger depths, two subtrees
    of depth ``tree_depth - 1`` are grown one after the other in the travel ``direction``
    (``direction == 1`` extends to the right, anything else extends to the left) and then
    merged into a single ``_TreeInfo``.

    Doubling stops early, returning the first subtree unchanged, when that first subtree
    reports ``turning`` or ``diverging``.

    :param z: starting position, forwarded to the recursive/base-tree builders.
    :param r: starting momentum, forwarded to the recursive/base-tree builders.
    :param z_grads: gradients associated with ``z``, forwarded alongside it.
    :param log_slice: forwarded unchanged to every recursive call.
    :param direction: travel direction; ``1`` means "grow to the right".
    :param int tree_depth: number of doublings still to perform.
    :param energy_current: forwarded unchanged to every recursive call.
    :returns: a ``_TreeInfo`` holding the merged tree's outer leaves, the selected proposal,
        the summed ``r_sum``, the combined weight, the ``turning``/``diverging`` flags and
        the accumulated acceptance statistics.
    """
    if tree_depth == 0:
        return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)

    going_right = direction == 1
    child_depth = tree_depth - 1

    # Grow the first half starting from the given state.
    first = self._build_tree(z, r, z_grads, log_slice, direction, child_depth, energy_current)

    # A turning or diverging first half ends the doubling; the second half is never built.
    if first.turning or first.diverging:
        return first

    # The second half continues from the first half's outermost leaf on the travel side.
    if going_right:
        start_z, start_r, start_grads = first.z_right, first.r_right, first.z_right_grads
    else:
        start_z, start_r, start_grads = first.z_left, first.r_left, first.z_left_grads
    second = self._build_tree(start_z, start_r, start_grads, log_slice,
                              direction, child_depth, energy_current)

    # Merge the bookkeeping of both halves.
    multinomial = self.use_multinomial_sampling
    if multinomial:
        # Weights are combined in log space in this mode.
        merged_weight = _logaddexp(first.weight, second.weight)
    else:
        merged_weight = first.weight + second.weight
    merged_accept_probs = first.sum_accept_probs + second.sum_accept_probs
    merged_num_proposals = first.num_proposals + second.num_proposals
    merged_r_sum = {name: first.r_sum[name] + second.r_sum[name]
                    for name in self.inverse_mass_matrix}

    # Probability that the merged tree's proposal is taken from the second half,
    # proportional to each half's share of the total weight.
    if multinomial:
        prob_second = (second.weight - merged_weight).exp()
    elif merged_weight > 0:
        prob_second = second.weight / merged_weight
    else:
        # Both halves have zero weight: keep the first half's proposal. Any choice is fine,
        # since a zero-weight proposal ends up being picked with probability 0.
        prob_second = scalar_like(merged_weight, 0.)
    pick_second = pyro.sample("is_other_half_tree",
                              dist.Bernoulli(probs=prob_second))
    chosen = second if pick_second == 1 else first

    # Outer leaves of the merged tree: the half grown last sits on the travel side.
    left_tree, right_tree = (first, second) if going_right else (second, first)

    # The first half was already checked above; check the second half, then the merged tree.
    turning = second.turning or self._is_turning(left_tree.r_left_unscaled,
                                                 right_tree.r_right_unscaled,
                                                 merged_r_sum)

    return _TreeInfo(left_tree.z_left, left_tree.r_left, left_tree.r_left_unscaled,
                     left_tree.z_left_grads,
                     right_tree.z_right, right_tree.r_right, right_tree.r_right_unscaled,
                     right_tree.z_right_grads,
                     chosen.z_proposal, chosen.z_proposal_pe, chosen.z_proposal_grads,
                     merged_r_sum, merged_weight,
                     # Only the second half can still be diverging at this point.
                     turning, second.diverging,
                     merged_accept_probs, merged_num_proposals)