NutsProposer::Tree NutsProposer::build_tree()

in src/beanmachine/graph/global/proposer/nuts_proposer.cpp [174:253]


// Recursively builds a tree of the requested depth by doubling.
//
// A depth-0 tree is delegated to build_tree_base_case(). A deeper tree is
// assembled from two subtrees of depth `tree_depth - 1`: the first starts at
// (`position`, `momentum`), and the second continues from the outer edge of
// the first on the side selected by `direction`.
//
// Parameters:
//   state            - forwarded unchanged down to build_tree_base_case().
//   gen              - random engine; passed to the recursive calls and used
//                      here for one Bernoulli draw that decides which
//                      subtree's `position_new` is kept.
//   position,
//   momentum         - starting point of the first subtree.
//   slice            - forwarded unchanged to the recursive calls / base case.
//   direction        - a negative value continues from the left edge of the
//                      first subtree; any other value continues from the
//                      right edge.
//   tree_depth       - remaining recursion depth; 0 selects the base case.
//   hamiltonian_init - forwarded unchanged to the recursive calls / base case.
//
// Returns the combined tree, or the first subtree as-is when its `no_turn`
// flag is false (the second subtree is then never built).
NutsProposer::Tree NutsProposer::build_tree(
    GlobalState& state,
    std::mt19937& gen,
    Eigen::VectorXd position,
    Eigen::VectorXd momentum,
    double slice,
    double direction,
    int tree_depth,
    double hamiltonian_init) {
  if (tree_depth == 0) {
    return build_tree_base_case(
        state, position, momentum, slice, direction, hamiltonian_init);
  }

  const int child_depth = tree_depth - 1;

  // First half: starts from the point we were handed.
  Tree inner = build_tree(
      state,
      gen,
      position,
      momentum,
      slice,
      direction,
      child_depth,
      hamiltonian_init);
  if (!inner.no_turn) {
    // The first half already cleared its no_turn flag; hand it back untouched.
    return inner;
  }

  // Second half: continues from whichever edge of the first half faces the
  // direction of travel.
  const bool grow_left = direction < 0;
  Tree outer = build_tree(
      state,
      gen,
      grow_left ? inner.position_left : inner.position_right,
      grow_left ? inner.momentum_left : inner.momentum_right,
      slice,
      direction,
      child_depth,
      hamiltonian_init);

  // After growing, the newly built half owns the edge on the growth side and
  // the first half keeps the opposite edge.
  const Tree& left_half = grow_left ? outer : inner;
  const Tree& right_half = grow_left ? inner : outer;

  Tree merged = Tree();
  merged.position_left = left_half.position_left;
  merged.momentum_left = left_half.momentum_left;
  merged.position_right = right_half.position_right;
  merged.momentum_right = right_half.momentum_right;

  // Keep the second half's candidate with probability proportional to its
  // share of the valid nodes; the max() guards against a zero denominator.
  const auto combined_valid = inner.valid_nodes + outer.valid_nodes;
  const double update_prob =
      outer.valid_nodes / std::max(1.0, combined_valid);
  std::bernoulli_distribution pick_outer(update_prob);
  merged.position_new =
      pick_outer(gen) ? outer.position_new : inner.position_new;

  merged.acceptance_sum = inner.acceptance_sum + outer.acceptance_sum;
  merged.total_nodes = inner.total_nodes + outer.total_nodes;
  merged.valid_nodes = combined_valid;

  // compute_no_turn() is only consulted when the second half itself still
  // has no_turn set (short-circuit evaluation).
  merged.no_turn = outer.no_turn &&
      compute_no_turn(
                       merged.position_left,
                       merged.momentum_left,
                       merged.position_right,
                       merged.momentum_right);

  return merged;
}