in src/beanmachine/graph/global/proposer/nuts_proposer.cpp [174:253]
// Recursively builds a NUTS trajectory subtree of depth `tree_depth`,
// growing in the direction indicated by the sign of `direction`.
//
// Depth 0 is delegated to build_tree_base_case. For deeper trees, an "inner"
// subtree of depth tree_depth - 1 is built from (position, momentum). If that
// subtree already reports a U-turn (no_turn == false), it is returned as-is.
// Otherwise an "outer" subtree of the same depth is grown from the inner
// subtree's frontier that faces `direction`: the left frontier when
// direction < 0, the right frontier otherwise. The two halves are then merged.
//
// Merging:
//   - Frontiers: the left end comes from the left-most half and the right end
//     from the right-most half.
//   - Proposal: position_new is taken from the outer half with probability
//     outer.valid_nodes / max(1, inner.valid_nodes + outer.valid_nodes), else
//     from the inner half.
//   - Counters: acceptance_sum, total_nodes and valid_nodes are summed.
//   - no_turn: true only if the outer half had no U-turn AND
//     compute_no_turn holds across the merged frontiers.
//
// NOTE(review): this appears to follow Algorithm 3 of Hoffman & Gelman (2014),
// "The No-U-Turn Sampler"; confirm before relying on that correspondence.
//
// @param state            global state forwarded to recursive/base-case calls
// @param gen              random engine; consumed by the recursive calls and by
//                         one Bernoulli draw per merge
// @param position         starting position for this subtree
// @param momentum         starting momentum for this subtree
// @param slice            forwarded unchanged to the base case
// @param direction        only its sign (< 0 or not) selects the side grown here;
//                         also forwarded to the base case
// @param tree_depth       number of doublings remaining
// @param hamiltonian_init forwarded unchanged to the base case
// @return the merged subtree, or the inner subtree if it already turned
NutsProposer::Tree NutsProposer::build_tree(
    GlobalState& state,
    std::mt19937& gen,
    Eigen::VectorXd position,
    Eigen::VectorXd momentum,
    double slice,
    double direction,
    int tree_depth,
    double hamiltonian_init) {
  if (tree_depth == 0) {
    return build_tree_base_case(
        state, position, momentum, slice, direction, hamiltonian_init);
  }

  const int child_depth = tree_depth - 1;

  Tree inner = build_tree(
      state,
      gen,
      position,
      momentum,
      slice,
      direction,
      child_depth,
      hamiltonian_init);

  // The first half already doubled back on itself: stop expanding.
  if (!inner.no_turn) {
    return inner;
  }

  const bool grow_left = direction < 0;

  // Continue the trajectory from the inner subtree's frontier facing
  // `direction`.
  const Eigen::VectorXd& start_position =
      grow_left ? inner.position_left : inner.position_right;
  const Eigen::VectorXd& start_momentum =
      grow_left ? inner.momentum_left : inner.momentum_right;

  Tree outer = build_tree(
      state,
      gen,
      start_position,
      start_momentum,
      slice,
      direction,
      child_depth,
      hamiltonian_init);

  // Whichever half was grown leftward supplies the merged left end; the other
  // half supplies the right end.
  const Tree& left_half = grow_left ? outer : inner;
  const Tree& right_half = grow_left ? inner : outer;

  Tree merged{};
  merged.position_left = left_half.position_left;
  merged.momentum_left = left_half.momentum_left;
  merged.position_right = right_half.position_right;
  merged.momentum_right = right_half.momentum_right;

  // Take the outer half's proposal with probability proportional to its share
  // of valid nodes. The max(1.0, ...) guards against division by zero when
  // neither half has any valid node.
  const auto combined_valid = inner.valid_nodes + outer.valid_nodes;
  std::bernoulli_distribution take_outer(
      outer.valid_nodes / std::max(1.0, combined_valid));
  merged.position_new =
      take_outer(gen) ? outer.position_new : inner.position_new;

  merged.acceptance_sum = inner.acceptance_sum + outer.acceptance_sum;
  merged.total_nodes = inner.total_nodes + outer.total_nodes;
  merged.valid_nodes = combined_valid;

  // Short-circuit: compute_no_turn is evaluated only when the outer half
  // itself did not turn.
  merged.no_turn = outer.no_turn &&
      compute_no_turn(
                       merged.position_left,
                       merged.momentum_left,
                       merged.position_right,
                       merged.momentum_right);

  return merged;
}