double NutsProposer::propose()

in src/beanmachine/graph/global/proposer/nuts_proposer.cpp [256:338]


/**
 * Runs one NUTS transition and stores the selected position in `state`.
 *
 * Starting from the current flattened unconstrained values, this draws an
 * initial momentum and a slice value (log of a uniform draw minus the initial
 * Hamiltonian). It then repeatedly extends the trajectory at a randomly chosen
 * end by calling build_tree() with depth 0, 1, 2, ... The loop stops when the
 * new subtree's `no_turn` flag is false, when compute_no_turn() returns false
 * for the two trajectory ends, or after max_tree_depth iterations.
 *
 * Side effects: assigns warmup_acceptance_prob, overwrites the flattened
 * unconstrained values held by `state`, and calls state.update_log_prob().
 *
 * @param state global state supplying the start point and receiving the result
 * @param gen   random number generator used for every draw made here
 * @return always 0.0
 */
double NutsProposer::propose(GlobalState& state, std::mt19937& gen) {
  Eigen::VectorXd position;
  state.get_flattened_unconstrained_values(position);

  // The momentum is drawn before the slice value; keeping this order keeps
  // the sequence of draws taken from `gen` unchanged.
  Eigen::VectorXd momentum_start = initialize_momentum(position, gen);
  double hamiltonian_init = compute_hamiltonian(state, position, momentum_start);
  std::uniform_real_distribution<double> unit_uniform(0.0, 1.0);
  double slice = std::log(unit_uniform(gen)) - hamiltonian_init;

  // Both ends of the trajectory start at the initial point.
  Eigen::VectorXd position_left = position;
  Eigen::VectorXd momentum_left = momentum_start;
  Eigen::VectorXd position_right = position;
  Eigen::VectorXd momentum_right = momentum_start;

  double valid_count = 1.0;
  // Statistics of the most recently built subtree; they are overwritten on
  // every iteration, not accumulated.
  double last_acceptance_sum = 0.0;
  double last_total_nodes = 0.0;

  std::bernoulli_distribution coin_flip(0.5);

  for (int depth = 0; depth < max_tree_depth; ++depth) {
    // A fair coin decides which end of the trajectory is extended.
    const bool extend_right = coin_flip(gen);
    double direction = extend_right ? 1.0 : -1.0;

    Eigen::VectorXd& edge_position =
        extend_right ? position_right : position_left;
    Eigen::VectorXd& edge_momentum =
        extend_right ? momentum_right : momentum_left;

    Tree subtree = build_tree(
        state,
        gen,
        edge_position,
        edge_momentum,
        slice,
        direction,
        depth,
        hamiltonian_init);

    // Only the end that was extended is replaced by the subtree's outer edge.
    edge_position =
        extend_right ? subtree.position_right : subtree.position_left;
    edge_momentum =
        extend_right ? subtree.momentum_right : subtree.momentum_left;

    last_acceptance_sum = subtree.acceptance_sum;
    last_total_nodes = subtree.total_nodes;

    // A subtree whose no_turn flag is false ends the transition before its
    // candidate position is considered.
    if (!subtree.no_turn) {
      break;
    }

    // Switch to the subtree's candidate with probability
    // min(1, subtree.valid_nodes / valid_count).
    double swap_prob = std::min(1.0, subtree.valid_nodes / valid_count);
    std::bernoulli_distribution swap_dist(swap_prob);
    if (swap_dist(gen)) {
      position = subtree.position_new;
    }
    valid_count += subtree.valid_nodes;

    // Stop once the check across the two trajectory ends fails.
    if (!compute_no_turn(
            position_left, momentum_left, position_right, momentum_right)) {
      break;
    }
  }

  // NOTE(review): if the loop body never runs (max_tree_depth <= 0) both
  // operands are still 0.0 and this quotient is NaN.
  warmup_acceptance_prob = last_acceptance_sum / last_total_nodes;

  state.set_flattened_unconstrained_values(position);
  state.update_log_prob();

  return 0.0;
}