void RlRunner::sample_state_to_leaf()

in csrc/liars_dice/recursive_solving.cc [192:246]


void RlRunner::sample_state_to_leaf(const ISubgameSolver* solver) {
  const auto& tree = solver->get_tree();
  // List of (node, action) pairs.
  std::vector<std::pair<int, Action>> path;
  {
    int node_id = 0;
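    // Choose which player's turns are subject to epsilon-uniform exploration below.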
    const auto br_sampler = std::uniform_int_distribution<>(0, 1)(gen_);
    const auto& strategy = solver->get_sampling_strategy();
    auto sampling_beliefs = beliefs_;
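    // Local copy of the current beliefs; beliefs_ itself is only updated in the second pass.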
    while (tree[node_id].num_children()) {
      const auto eps = std::uniform_real_distribution<float>(0, 1)(gen_);
      Action action;
      const auto& state = tree[node_id].state;
      const auto [action_begin, action_end] = game_.get_bid_range(state);
      if (state.player_id == br_sampler && eps < random_action_prob_) {
        std::uniform_int_distribution<> dis(action_begin, action_end - 1);
        action = dis(gen_);
      } else {
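        // Sample a private hand from the acting player's beliefs, then an action
        // from that hand's policy at this node.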
        const auto& beliefs = sampling_beliefs[state.player_id];
        std::discrete_distribution<> dis(beliefs.begin(), beliefs.end());
        const int hand = dis(gen_);
        const std::vector<double>& policy = strategy[node_id][hand];
        std::discrete_distribution<> action_dis(policy.begin(), policy.end());
        action = action_dis(gen_);
        assert(action >= action_begin && action < action_end);
      }
      // Update beliefs.
      // Policy[hand, action] := P(action | hand).
      const auto& policy = strategy[node_id];
      // P^{t+1}(hand|action) \propto P^t(action|hand) P^t(hand).
      for (int hand = 0; hand < game_.num_hands(); ++hand) {
        // Assuming that the policy has zeros outside of the range.
        sampling_beliefs[state.player_id][hand] *= policy[hand][action];
      }
      normalize_beliefs_inplace(sampling_beliefs[state.player_id]);
      path.emplace_back(node_id, action);
      node_id = tree[node_id].children_begin + action - action_begin;
    }
  }

  // We do another pass over the path to update beliefs_ according to
  // `get_belief_propogation_strategy`, which may differ from the sampling
  // strategy.
  for (auto [node_id, action] : path) {
    const auto action_begin = game_.get_bid_range(state_).first;
    const auto& policy = solver->get_belief_propogation_strategy()[node_id];
    for (int hand = 0; hand < game_.num_hands(); ++hand) {
      // Assuming that the policy has zeros outside of the range.
      beliefs_[state_.player_id][hand] *= policy[hand][action];
    }
    normalize_beliefs_inplace(beliefs_[state_.player_id]);
    int child_node_id = tree[node_id].children_begin + action - action_begin;
    state_ = tree[child_node_id].state;
  }
}
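
For concreteness, the belief update in both loops is a single Bayes step over hands (the numbers below are illustrative only): with a prior over three hands of (0.5, 0.3, 0.2) and policy[hand][action] = (0.2, 0.5, 0.0) for the sampled action, the element-wise product is (0.10, 0.15, 0.00), which normalization rescales to (0.4, 0.6, 0.0).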
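
normalize_beliefs_inplace is used above but not shown in this excerpt. A minimal sketch of the assumed behavior (the actual signature and element type in the repository may differ):

#include <vector>

// Hypothetical sketch: rescale a belief vector so its entries sum to one,
// leaving it unchanged if the total mass is zero.
inline void normalize_beliefs_inplace(std::vector<double>& beliefs) {
  double sum = 0;
  for (double p : beliefs) sum += p;
  if (sum <= 0) return;
  for (double& p : beliefs) p /= sum;
}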