def _build_tree()

in pyro/infer/mcmc/nuts.py [0:0]


    def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
        """Recursively grow a NUTS subtree of depth ``tree_depth`` from a trajectory leaf.

        Depth 0 is delegated to ``self._build_basetree``. For larger depths the tree is made of
        two subtrees of depth ``tree_depth - 1``:

        - The first subtree starts at ``(z, r, z_grads)``.
        - The second continues from the outermost leaf of the first in ``direction``.

        If the first subtree is already turning or diverging, it is returned unchanged and the
        second subtree is never built.

        :param z: position state the tree grows from.
        :param r: momentum state the tree grows from.
        :param z_grads: gradients associated with ``z``.
        :param log_slice: slice value, forwarded unchanged to the base-tree builder.
        :param direction: ``1`` grows the tree to the right; any other value grows it to the left.
        :param int tree_depth: depth of the tree to build; ``0`` builds a base tree.
        :param energy_current: forwarded unchanged to the base-tree builder.
        :returns: a ``_TreeInfo`` describing the merged tree, with these fields:

            - its outer leaves;
            - the proposal selected from one of the halves;
            - the summed ``r_sum``;
            - the combined weight;
            - the ``turning``/``diverging`` flags;
            - the summed acceptance probabilities and proposal count.
        """
        if tree_depth == 0:
            return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)

        subtree_depth = tree_depth - 1

        first = self._build_tree(z, r, z_grads, log_slice,
                                 direction, subtree_depth, energy_current)
        # Once the first half has hit a stop condition, the second half is not needed.
        if first.turning or first.diverging:
            return first

        going_right = direction == 1

        # Resume the trajectory from the first half's outermost leaf in the direction of travel.
        if going_right:
            start_z, start_r, start_grads = first.z_right, first.r_right, first.z_right_grads
        else:
            start_z, start_r, start_grads = first.z_left, first.r_left, first.z_left_grads
        second = self._build_tree(start_z, start_r, start_grads, log_slice,
                                  direction, subtree_depth, energy_current)

        # Combine the halves' weights and derive the probability that the merged tree's
        # proposal is taken from the second half.
        if self.use_multinomial_sampling:
            # Weights are treated as log-weights here. They combine via log-add-exp, and the
            # second half's share is exp(w_second - w_total).
            tree_weight = _logaddexp(first.weight, second.weight)
            second_half_prob = (second.weight - tree_weight).exp()
        else:
            tree_weight = first.weight + second.weight
            if tree_weight > 0:
                second_half_prob = second.weight / tree_weight
            else:
                # Both halves weigh 0, so keep the first half's proposal. Either choice would do:
                # a zero-weight tree's proposal has probability 0 of being picked in the end.
                second_half_prob = scalar_like(tree_weight, 0.)

        # Statistics of the merged tree are the element-wise sums of the two halves.
        sum_accept_probs = first.sum_accept_probs + second.sum_accept_probs
        num_proposals = first.num_proposals + second.num_proposals
        r_sum = {name: first.r_sum[name] + second.r_sum[name] for name in self.inverse_mass_matrix}

        take_second = pyro.sample("is_other_half_tree",
                                  dist.Bernoulli(probs=second_half_prob))
        proposal_tree = second if take_second == 1 else first

        # The outer leaves of the merged tree come from whichever half lies on that side.
        left_tree, right_tree = (first, second) if going_right else (second, first)

        # The first half was checked for turning above; check the second half, then the full span.
        turning = second.turning or self._is_turning(left_tree.r_left_unscaled, right_tree.r_right_unscaled, r_sum)
        # Likewise, only the second half can still report a divergence.
        diverging = second.diverging

        return _TreeInfo(left_tree.z_left, left_tree.r_left, left_tree.r_left_unscaled,
                         left_tree.z_left_grads,
                         right_tree.z_right, right_tree.r_right, right_tree.r_right_unscaled,
                         right_tree.z_right_grads,
                         proposal_tree.z_proposal, proposal_tree.z_proposal_pe,
                         proposal_tree.z_proposal_grads,
                         r_sum, tree_weight, turning, diverging, sum_accept_probs, num_proposals)