in csrc/liars_dice/recursive_solving.cc [192:246]
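// Samples a trajectory from the current state_ to a leaf of the solver's
// subgame tree, updating beliefs_ along the way.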
void RlRunner::sample_state_to_leaf(const ISubgameSolver* solver) {
const auto& tree = solver->get_tree();
// The sampled trajectory as (node_id, action) pairs; replayed in the
// second pass below.
std::vector<std::pair<int, Action>> path;
{
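// First pass: sample a root-to-leaf path using the solver's sampling
// strategy.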
int node_id = 0;
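// Pick one of the two players uniformly at random; only that player takes
// exploratory random actions on this rollout.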
const auto br_sampler = std::uniform_int_distribution<>(0, 1)(gen_);
const auto& strategy = solver->get_sampling_strategy();
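// Work on a copy of the beliefs; the member beliefs_ is only updated in
// the second pass below.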
auto sampling_beliefs = beliefs_;
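// Walk the tree until a leaf (a node with no children) is reached.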
while (tree[node_id].num_children()) {
const auto eps = std::uniform_real_distribution<float>(0, 1)(gen_);
Action action;
const auto& state = tree[node_id].state;
const auto [action_begin, action_end] = game_.get_bid_range(state);
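// With probability random_action_prob_, the designated explorer plays a
// uniformly random legal bid.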
if (state.player_id == br_sampler && eps < random_action_prob_) {
std::uniform_int_distribution<> dis(action_begin, action_end - 1);
action = dis(gen_);
} else {
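// Otherwise, sample a hand for the acting player from the current
// beliefs, then sample an action from that hand's strategy.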
const auto& beliefs = sampling_beliefs[state.player_id];
std::discrete_distribution<> dis(beliefs.begin(), beliefs.end());
const int hand = dis(gen_);
const std::vector<double>& policy = strategy[node_id][hand];
std::discrete_distribution<> action_dis(policy.begin(), policy.end());
action = action_dis(gen_);
assert(action >= action_begin && action < action_end);
}
// Update the acting player's beliefs given the sampled action.
// policy[hand][action] := P(action | hand).
const auto& policy = strategy[node_id];
// Bayes' rule: P^{t+1}(hand|action) \propto P^t(action|hand) P^t(hand).
for (int hand = 0; hand < game_.num_hands(); ++hand) {
// Assumes the policy assigns zero probability outside
// [action_begin, action_end), so no masking is needed.
sampling_beliefs[state.player_id][hand] *= policy[hand][action];
}
normalize_beliefs_inplace(sampling_beliefs[state.player_id]);
path.emplace_back(node_id, action);
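// Children are stored contiguously and ordered by action, so the child is
// indexed by the action's offset from action_begin.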
node_id = tree[node_id].children_begin + action - action_begin;
}
}
// Second pass over the path: update beliefs_ according to
// `get_belief_propogation_strategy()`, which may differ from the sampling
// strategy used above.
for (auto [node_id, action] : path) {
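// state_ mirrors tree[node_id].state along the path: it begins at the
// root state and is advanced to the child state at the end of each
// iteration.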
const auto action_begin = game_.get_bid_range(state_).first;
const auto& policy = solver->get_belief_propogation_strategy()[node_id];
for (int hand = 0; hand < game_.num_hands(); ++hand) {
// Assumes the policy assigns zero probability outside the legal bid
// range.
beliefs_[state_.player_id][hand] *= policy[hand][action];
}
normalize_beliefs_inplace(beliefs_[state_.player_id]);
int child_node_id = tree[node_id].children_begin + action - action_begin;
state_ = tree[child_node_id].state;
}
}