in src/beanmachine/graph/global/proposer/nuts_proposer.cpp [256:338]
// Draws one NUTS sample and writes it back into `state`.
//
// The code follows the slice-variable / trajectory-doubling scheme of the
// No-U-Turn Sampler (Hoffman & Gelman 2014):
//   1. Sample a fresh momentum and a log slice variable.
//   2. Repeatedly extend the trajectory from a randomly chosen endpoint via
//      `build_tree`.
//   3. Stop when a subtree reports `no_turn == false`, when the endpoints
//      make a U-turn (`compute_no_turn`), or after `max_tree_depth`
//      extensions.
//
// After each valid subtree, the current position moves to
// `tree.position_new` with probability min(1, n_subtree / n_so_far).
//
// Side effects:
//   - Sets `warmup_acceptance_prob` from the acceptance statistics of the
//     last subtree built. It is left unchanged when no subtree was built
//     (e.g. `max_tree_depth <= 0`), instead of being set to 0/0 = NaN.
//   - Leaves `state` holding the chosen position, with `update_log_prob()`
//     called on it.
//
// Returns 0.0 unconditionally. Presumably the caller interprets this as a
// log acceptance ratio, so the NUTS sample is always kept. NOTE(review):
// confirm against the proposer interface.
double NutsProposer::propose(GlobalState& state, std::mt19937& gen) {
  Eigen::VectorXd position;
  state.get_flattened_unconstrained_values(position);

  // Sample momentum, then the log slice variable:
  //   slice = log(U(0, 1)) - H(position, momentum).
  Eigen::VectorXd momentum_init = initialize_momentum(position, gen);
  std::uniform_real_distribution<double> uniform_dist(0.0, 1.0);
  double hamiltonian_init = compute_hamiltonian(state, position, momentum_init);
  double slice = std::log(uniform_dist(gen)) - hamiltonian_init;

  // Left and right endpoints of the trajectory; both start at the current
  // point.
  Eigen::VectorXd position_left = position;
  Eigen::VectorXd position_right = position;
  Eigen::VectorXd momentum_left = momentum_init;
  Eigen::VectorXd momentum_right = momentum_init;

  // Running count used as the denominator of the subtree-selection
  // probability. It starts at 1, counting the initial point.
  double valid_nodes = 1;
  double acceptance_sum = 0.0;
  double total_nodes = 0.0;
  std::bernoulli_distribution coin_flip(0.5);

  for (int tree_depth = 0; tree_depth < max_tree_depth; tree_depth++) {
    // Sample a direction and extend the trajectory from the matching
    // endpoint. Only that endpoint is replaced afterwards.
    const bool go_right = coin_flip(gen);
    const double direction = go_right ? 1.0 : -1.0;
    Eigen::VectorXd& frontier_position =
        go_right ? position_right : position_left;
    Eigen::VectorXd& frontier_momentum =
        go_right ? momentum_right : momentum_left;
    Tree tree = build_tree(
        state,
        gen,
        frontier_position,
        frontier_momentum,
        slice,
        direction,
        tree_depth,
        hamiltonian_init);
    frontier_position = go_right ? tree.position_right : tree.position_left;
    frontier_momentum = go_right ? tree.momentum_right : tree.momentum_left;

    // Acceptance statistics are taken from the most recent subtree only.
    acceptance_sum = tree.acceptance_sum;
    total_nodes = tree.total_nodes;
    if (!tree.no_turn) {
      // The new subtree is flagged as stopped: do not sample from it.
      break;
    }

    // Move into the new subtree with probability min(1, n_subtree / n_so_far).
    double update_prob = std::min(1.0, tree.valid_nodes / valid_nodes);
    std::bernoulli_distribution update_dist(update_prob);
    if (update_dist(gen)) {
      position = tree.position_new;
    }
    valid_nodes += tree.valid_nodes;

    // Stop once the full trajectory's endpoints fail the no-U-turn check.
    if (!compute_no_turn(
            position_left, momentum_left, position_right, momentum_right)) {
      break;
    }
  }

  // If no subtree was built (e.g. max_tree_depth <= 0), total_nodes is still
  // 0. Dividing would store NaN, so the previous value is kept instead.
  if (total_nodes > 0.0) {
    warmup_acceptance_prob = acceptance_sum / total_nodes;
  }
  state.set_flattened_unconstrained_values(position);
  state.update_log_prob();
  return 0.0;
}