in csrc/liars_dice/recursive_eval.cc [193:426]
int main(int argc, char* argv[]) {
int num_dice = 1;
int num_faces = 4;
int subgame_iters = 1024;
int mdp_depth = -1;
int num_repeats = -1;
std::string net_path;
bool repeat_oracle_net = false;
bool no_linear = false;
bool root_only = false;
bool print_regret = false;
bool print_regret_summary = false;
int eval_oracle_values_iters = -1;
int num_threads = 10;
SubgameSolvingParams base_params;
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
{
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "--num_dice") {
assert(i + 1 < argc);
num_dice = std::stoi(argv[++i]);
} else if (arg == "--num_faces") {
assert(i + 1 < argc);
num_faces = std::stoi(argv[++i]);
} else if (arg == "--subgame_iters") {
assert(i + 1 < argc);
subgame_iters = std::stoi(argv[++i]);
} else if (arg == "--mdp_depth") {
assert(i + 1 < argc);
mdp_depth = std::stoi(argv[++i]);
} else if (arg == "--num_threads") {
assert(i + 1 < argc);
num_threads = std::stoi(argv[++i]);
} else if (arg == "--num_repeats") {
assert(i + 1 < argc);
num_repeats = std::stoi(argv[++i]);
} else if (arg == "--root_only") {
root_only = true;
} else if (arg == "--repeat_oracle_net") {
repeat_oracle_net = true;
} else if (arg == "--net") {
assert(i + 1 < argc);
net_path = argv[++i];
} else if (arg == "--print_regret") {
print_regret = true;
} else if (arg == "--print_regret_summary") {
print_regret_summary = true;
} else if (arg == "--no_linear") {
no_linear = true;
} else if (arg == "--optimistic") {
base_params.optimistic = true;
} else if (arg == "--eval_oracle_values_iters") {
assert(i + 1 < argc);
eval_oracle_values_iters = std::stoi(argv[++i]);
} else if (arg == "--cfr") {
base_params.use_cfr = true;
} else if (arg == "--dcfr") {
base_params.dcfr = true;
base_params.dcfr_alpha = std::atof(argv[++i]);
base_params.dcfr_beta = std::atof(argv[++i]);
base_params.dcfr_gamma = std::atof(argv[++i]);
} else {
std::cerr << "Unknown flag: " << arg << "\n";
return -1;
}
}
}
assert(num_dice != -1);
assert(num_faces != -1);
const Game game(num_dice, num_faces);
std::cout << "num_dice=" << num_dice << " num_faces=" << num_faces << "\n";
const auto full_tree = unroll_tree(game);
std::cout << "Tree of depth " << get_depth(full_tree) << " has "
<< full_tree.size() << " nodes\n";
std::cout << "##############################################\n";
std::cout << "##### Solving the game for the full tree #####\n";
std::cout << "##############################################\n";
TreeStrategy full_strategy;
base_params.num_iters = subgame_iters;
base_params.linear_update = !no_linear && !base_params.dcfr;
{
SubgameSolvingParams params = base_params;
params.max_depth = 100000;
auto fp = build_solver(game, params);
std::vector<TreeStrategy> strategy_list;
for (int iter = 0; iter < subgame_iters; ++iter) {
fp->step(iter % 2);
if (iter % 2 == 0 && params.use_cfr) {
strategy_list.push_back(fp->get_sampling_strategy());
}
if (((iter + 1) & iter) == 0 || iter + 1 == subgame_iters) {
auto values = compute_exploitability2(game, fp->get_strategy());
printf("Iter=%8d exploitabilities=(%.3e, %.3e) sum=%.3e\n", iter + 1,
values[0], values[1], (values[0] + values[1]) / 2.);
}
}
full_strategy = fp->get_strategy();
auto explotabilities = compute_exploitability2(game, full_strategy);
std::cout << "Full FP exploitability: "
<< (explotabilities[0] + explotabilities[1]) / 2. << " ("
<< explotabilities[0] << "," << explotabilities[1] << ")"
<< std::endl;
// report_game_stats(game, fp->get_strategy());
if (!strategy_list.empty()) {
report_regrets(game, strategy_list, print_regret, print_regret_summary,
mdp_depth);
std::cout << "\n";
}
print_strategy(game, fp->get_tree(), fp->get_strategy(),
"strategy.full.txt");
}
std::vector<std::pair<std::string, TreeStrategy>> all_strategies;
all_strategies.emplace_back("full_tree", full_strategy);
if (!net_path.empty()) {
assert(mdp_depth > 0);
std::shared_ptr<IValueNet> net =
net_path == "zero"
? liars_dice::create_zero_net(game.num_hands(), false)
: liars_dice::create_torchscript_net(net_path);
std::cout << "##############################################\n";
std::cout << "##### Recursive solving #\n";
std::cout << "##############################################\n";
if (num_repeats > 0) {
torch::Tensor summed_stategy, summed_reach;
TreeStrategy final_strategy;
auto net_builder = [=]() {
if (repeat_oracle_net) {
SubgameSolvingParams oracle_net_params = base_params;
oracle_net_params.max_depth = 100000;
if (eval_oracle_values_iters > 0) {
oracle_net_params.num_iters = eval_oracle_values_iters;
}
return liars_dice::create_oracle_value_predictor(game,
oracle_net_params);
} else {
return liars_dice::create_torchscript_net(net_path, "cpu");
}
};
ParallerSampledStrategyComputor computer(game, net_builder, num_repeats,
base_params, mdp_depth,
num_threads, root_only);
std::vector<TreeStrategy> strategy_list;
for (int strategy_id = 0; strategy_id < num_repeats; ++strategy_id) {
{
auto [sampled_strategy_tensor, node_reach_tensor] =
computer.get(strategy_id);
if (strategy_id == 0) {
summed_stategy = sampled_strategy_tensor * node_reach_tensor;
summed_reach = node_reach_tensor;
} else {
summed_stategy += sampled_strategy_tensor * node_reach_tensor;
summed_reach += node_reach_tensor;
}
if (base_params.use_cfr) {
strategy_list.push_back(
tensor_to_tree_strategy(sampled_strategy_tensor));
}
}
final_strategy =
tensor_to_tree_strategy(summed_stategy / (summed_reach + 1e-6));
if (((strategy_id + 1) & strategy_id) == 0 ||
strategy_id + 1 == num_repeats) {
std::cout << std::setw(5) << strategy_id + 1 << ": ";
auto explotabilities = compute_exploitability2(game, final_strategy);
auto evs = compute_ev2(game, full_strategy, final_strategy);
std::cout << (explotabilities[0] + explotabilities[1]) / 2. << " ("
<< explotabilities[0] << "," << explotabilities[1] << ")"
<< "\tEV of full: ";
std::cout << (evs[0] + evs[1]) / 2 << " (" << evs[0] << "," << evs[1]
<< ")";
if (!strategy_list.empty()) {
report_regrets(game, strategy_list, print_regret,
print_regret_summary, mdp_depth);
}
std::cout << std::endl;
const auto name = (repeat_oracle_net ? "repeated oracle toleaf "
: "repeated toleaf ") +
std::to_string(strategy_id + 1);
all_strategies.emplace_back(name, final_strategy);
print_strategy(game, full_tree, final_strategy,
"strategy.repeated.txt");
}
}
}
}
// Reporting in human-readable format.
std::vector<std::pair<std::string, std::string>> result;
std::vector<std::pair<std::string, std::string>> result_ev;
result.emplace_back("net", net_path);
result_ev.emplace_back("net", net_path);
for (auto [name, mdp_strategy] : all_strategies) {
std::cout << " " << name << " ";
assert(mdp_strategy.size() == full_tree.size());
auto explotabilities = compute_exploitability2(game, mdp_strategy);
auto evs = compute_ev2(game, full_strategy, mdp_strategy);
std::cout << (explotabilities[0] + explotabilities[1]) / 2. << " ("
<< explotabilities[0] << "," << explotabilities[1] << ")"
<< "\n\tEV of full: ";
std::cout << (evs[0] + evs[1]) / 2. << " (" << evs[0] << "," << evs[1]
<< ")";
std::cout << std::endl;
result.emplace_back(
name, std::to_string((explotabilities[0] + explotabilities[1]) / 2.));
result_ev.emplace_back(name, std::to_string((evs[0] + evs[1]) / 2.));
}
// Reporting as JSON.
std::vector<
std::tuple<std::string, std::vector<std::pair<std::string, std::string>>>>
all_results = {{"XXX", result}, {"YYY", result_ev}};
for (auto [tag, dict] : all_results) {
std::cout << tag << " {";
bool first = true;
for (auto [k, v] : dict) {
if (first) {
first = false;
} else {
std::cout << ", ";
}
std::cout << '"' << k << "\":\"" << v << "\"";
}
std::cout << "}" << std::endl;
}
}