in simple_game/main.cpp [28:314]
/// Command-line entry point of the tabular game solver.
///
/// Steps, in order:
///   1. Register and parse the command-line options, echoing them to stdout.
///   2. Build the game selected by `--game` (and, for "comm"/"comm2", a
///      reference strategy object).
///   3. Unless `--no_cfr_init` is given, run the CFR solver for `--iter_cfr`
///      iterations and keep its values and policies.
///   4. If `--method cfr` was requested, print the CFR values as JSON and stop.
///   5. Otherwise initialise the search solver with a loaded, CFR-derived or
///      randomized policy, run the search for `--iter` iterations, and print
///      the results as JSON together with the elapsed time.
///   6. If a reference strategy exists and `--no_opt` is not set, evaluate it
///      and print the per-player values.
///
/// @param argc Number of command-line arguments.
/// @param argv Command-line arguments.
/// @return 0 on success.
/// @throws std::runtime_error if the game name is unknown, or if
///         `--method cfr` is requested but no CFR value is available for the
///         reported player (e.g. together with `--no_cfr_init`).
int main(int argc, char* argv[]) {
  cxxopts::Options cmdOptions(
      "Tabular Game Solver", "Simple Game Solver for Tabular Games");
  cmdOptions.add_options()(
      "g,game",
      "Game Name",
      cxxopts::value<std::string>()->default_value("comm"))(
      "method",
      "Name of method",
      cxxopts::value<std::string>()->default_value("search"))(
      "load_pi", "Load policy", cxxopts::value<std::string>())(
      "load_pi_log", "Load policy from log", cxxopts::value<std::string>())(
      "num_round",
      "#round for comm",
      cxxopts::value<int>()->default_value("4"))(
      "num_card", "#card for comm", cxxopts::value<int>()->default_value("-1"))(
      "seed", "Random Seed", cxxopts::value<int>()->default_value("1"))(
      "first_random_infoset",
      // Was "Random Seed" (copy-paste from the option above); the value is
      // stored in options.firstRandomInfoSetKey.
      "Key of the first random infoset",
      cxxopts::value<std::string>()->default_value(""))(
      "gt_compute",
      "Compute exhaustive search",
      cxxopts::value<bool>()->default_value("false"))(
      "gt_override",
      "Override research with gt result",
      cxxopts::value<bool>()->default_value("false"))(
      "perturb_chance",
      "Whether we purturb chance node",
      cxxopts::value<float>()->default_value("0.0"))(
      "perturb_policy",
      "Whether we purturb policy",
      cxxopts::value<float>()->default_value("0.0"))(
      "verbose", "Verbose level", cxxopts::value<int>()->default_value("1"))(
      "compute_reach",
      "Whether we compute reach",
      cxxopts::value<bool>()->default_value("false"))(
      "no_opt",
      "Not compare with optimal strategy",
      cxxopts::value<bool>()->default_value("false"))(
      "iter", "#Iteration", cxxopts::value<int>()->default_value("100"))(
      "iter_cfr",
      "#Iteration for cfr",
      cxxopts::value<int>()->default_value("1000"))(
      "no_cfr_init",
      "Do not use CFR for initialization",
      cxxopts::value<bool>()->default_value("false"))(
      "use_cfr_pure_init",
      "Use CFR pure strategy to init",
      cxxopts::value<bool>()->default_value("false"))(
      "dump_json",
      "Save strategy to JSON.",
      cxxopts::value<bool>()->default_value("false"))(
      "print_strategy_before_search",
      "Print Strategy Before Search",
      cxxopts::value<bool>()->default_value("false"))(
      "show_better",
      "whether show better policy when there is improvement",
      cxxopts::value<bool>()->default_value("false"))(
      "N,N_minibridge",
      "N in MiniBridge",
      cxxopts::value<int>()->default_value("3"))(
      "use_2nd_order", "", cxxopts::value<bool>()->default_value("false"))(
      "max_depth",
      "Max optimization depth (0 mean till the end)",
      cxxopts::value<int>()->default_value("0"))(
      "skip_single_infoset_opt",
      "",
      cxxopts::value<bool>()->default_value("false"))(
      "skip_same_delta_policy",
      "",
      cxxopts::value<bool>()->default_value("false"))(
      "num_samples",
      "#samples per infoset used in each iteration, 0 = use all",
      cxxopts::value<int>()->default_value("0"))(
      "num_samples_total",
      "#total number of samples across all infoset. 0 = not used",
      cxxopts::value<int>()->default_value("0"));

  // Echo the raw command line so a log file records how the run was started.
  std::cout << "Command line: ";
  for (int i = 0; i < argc; ++i) {
    std::cout << argv[i] << " ";
  }
  std::cout << std::endl;

  auto result = cmdOptions.parse(argc, argv);

  // Echo the options that were explicitly passed on the command line.
  for (const auto& kv : result.arguments()) {
    std::cout << kv.key() << ": " << kv.value() << std::endl;
  }

  const std::string gameName = result["game"].as<std::string>();

  // Options forwarded to the CFR and search solvers.
  tabular::Options options;
  options.seed = result["seed"].as<int>();
  options.method = result["method"].as<std::string>();
  options.verbose =
      static_cast<tabular::VerboseLevel>(result["verbose"].as<int>());
  options.perturbChance = result["perturb_chance"].as<float>();
  options.perturbPolicy = result["perturb_policy"].as<float>();
  options.firstRandomInfoSetKey =
      result["first_random_infoset"].as<std::string>();
  options.gtCompute = result["gt_compute"].as<bool>();
  options.gtOverride = result["gt_override"].as<bool>();
  options.computeReach = result["compute_reach"].as<bool>();
  options.use2ndOrder = result["use_2nd_order"].as<bool>();
  options.maxDepth = result["max_depth"].as<int>();
  options.showBetter = result["show_better"].as<bool>();
  options.skipSingleInfoSetOpt = result["skip_single_infoset_opt"].as<bool>();
  options.skipSameDeltaPolicy = result["skip_same_delta_policy"].as<bool>();
  options.numSample = result["num_samples"].as<int>();
  options.numSampleTotal = result["num_samples_total"].as<int>();

  // Settings that are only consumed inside this function.
  const int numIter = result["iter"].as<int>();
  const int numIterCFR = result["iter_cfr"].as<int>();
  const bool noCFRInit = result["no_cfr_init"].as<bool>();
  const bool useCFRPureInit = result["use_cfr_pure_init"].as<bool>();
  const bool printStrategyBeforeSearch =
      result["print_strategy_before_search"].as<bool>();
  const bool dumpJson = result["dump_json"].as<bool>();

  // Options handed to the game constructors below.
  simple::CommOptions gameOptions;
  gameOptions.numRound = result["num_round"].as<int>();
  gameOptions.possibleCards = result["num_card"].as<int>();
  gameOptions.N = result["N"].as<int>();

  // Index into the per-player value vectors that is reported in the JSON
  // summaries below.
  const int playerIdx = 1;
  // True when `values` actually holds an entry for `playerIdx`; indexing a
  // shorter (or empty) vector would be undefined behaviour.
  const auto hasPlayerValue = [&](const std::vector<float>& values) {
    return static_cast<int>(values.size()) > playerIdx;
  };

  // The reported time covers everything from game construction up to the
  // end of the search and its printing.
  auto start = high_resolution_clock::now();

  std::unique_ptr<rela::Env> game;
  // Only set for games that ship a reference strategy ("comm", "comm2").
  std::unique_ptr<rela::OptimalStrategy> strategy;
  if (gameName == "kuhn") {
    game = std::make_unique<simple::KuhnPoker>();
  } else if (gameName == "comm") {
    game = std::make_unique<simple::Communicate>(gameOptions);
    strategy = std::make_unique<simple::CommunicatePolicy>(gameOptions);
  } else if (gameName == "comm2") {
    game = std::make_unique<simple::Communicate2>(gameOptions);
    strategy = std::make_unique<simple::Communicate2Policy>();
  } else if (gameName == "simplebidding") {
    game = std::make_unique<simple::SimpleBidding>(gameOptions);
  } else if (gameName == "2suitedbridge") {
    game = std::make_unique<simple::TwoSuitedBridge>(gameOptions);
  } else if (gameName == "simplehanabi") {
    simple::hanabi::Options hanabiOptions;
    game = std::make_unique<simple::hanabi::SimpleHanabi>(hanabiOptions);
  } else {
    throw std::runtime_error(gameName + " is not implemented");
  }
  game->reset();

  /*
  tabular_cfr::CFR cfr(options);
  std::cout << "Initialize search tree.. " << std::endl;
  cfr.init(*game);
  std::cout << "Initialize done. #infoSet: " << cfr.infoSets().numInfoSets()
  << ", #node: " << cfr.infoSets().numNodes() << std::endl;
  cfr.infoSets().printInfoSetTree();
  */

  tabular::Policies policies;
  std::vector<float> v;        // values returned by the CFR run
  std::vector<float> vPure;    // values after purifying the CFR strategies
  std::vector<float> vLoaded;  // values of a policy loaded from disk

  if (!noCFRInit) {
    tabular::cfr::CFRSolver cfrSolver(
        options.seed, options.verbose == tabular::VerboseLevel::VERBOSE);
    std::cout << "Initialize CFR search tree" << std::endl;
    cfrSolver.init(*game);
    std::cout << "Run CFR for " << numIterCFR << " iterations." << std::endl;
    v = cfrSolver.run(numIterCFR);
    std::cout << "Result after CFR " << numIterCFR
              << " iterations with seed: " << options.seed << std::endl;
    // Printing starts at index 1; index 0 is skipped here.
    for (int i = 1; i < static_cast<int>(v.size()); ++i) {
      std::cout << "CFR Player " << i << " expected value: " << v[i]
                << std::endl;
    }
    // Also get the value after purification.
    // The order matters: with use_cfr_pure_init the strategies are purified
    // before `policies` is taken, otherwise `policies` is taken first and the
    // purification only feeds the evaluation below.
    if (useCFRPureInit) {
      cfrSolver.getInfos().purifyStrategies();
      policies = cfrSolver.getInfos().getStrategies();
    } else {
      policies = cfrSolver.getInfos().getStrategies();
      cfrSolver.getInfos().purifyStrategies();
    }
    vPure = cfrSolver.evaluate();
  }

  if (options.method == "cfr") {
    // Without a CFR run (e.g. --no_cfr_init) both vectors are empty; fail
    // with a clear message instead of reading past the end of the vectors.
    if (!hasPlayerValue(v) || !hasPlayerValue(vPure)) {
      throw std::runtime_error(
          "method=cfr requires CFR results for the reported player, but none "
          "are available (was --no_cfr_init set?)");
    }
    json j;
    j["CFR"] = v[playerIdx];
    j["CFRPure"] = vPure[playerIdx];
    std::cout << "json_str: " << j.dump() << std::endl;
    return 0;
  }

  tabular::search::Solver solver(options);
  std::cout << "Initialize search tree.. " << std::endl;
  solver.init(*game);
  std::cout << "Initialize done. #infoSet: " << solver.manager().numInfoSets()
            << ", #states: " << solver.manager().numStates() << std::endl;
  if (options.verbose == tabular::VerboseLevel::VERBOSE) {
    solver.manager().printInfoSetTree();
  }

  // Choose the starting policy of the search. Precedence: --load_pi, then
  // --load_pi_log, then the CFR result (or a random policy if CFR was
  // skipped).
  if (result["load_pi"].count()) {
    auto filename = result["load_pi"].as<std::string>();
    std::cout << "Loading pi: " << filename << std::endl;
    solver.loadPolicies(filename);
    solver.evaluate();
    vLoaded = solver.u();
  } else if (result["load_pi_log"].count()) {
    auto filename = result["load_pi_log"].as<std::string>();
    std::cout << "Loading pi_log: " << filename << std::endl;
    policies = tabular_utils::loadPolicy(*game, filename);
    // NOTE(review): randomizing before loading presumably gives infosets
    // absent from the log a policy as well — confirm against the solver.
    solver.manager().randomizePolicy();
    solver.loadPolicies(policies);
    solver.evaluate();
    vLoaded = solver.u();
  } else {
    if (noCFRInit) {
      solver.manager().randomizePolicy();
    } else {
      solver.loadPolicies(policies);
    }
  }

  if (printStrategyBeforeSearch) {
    solver.evaluate();
    solver.manager().printStrategy();
  }

  auto sampler = tabular::search::InfoSetsSampler(solver.manager());
  auto searchResult = solver.runSearch(1, numIter, sampler);

  // One-line JSON summary; a key is only emitted when the corresponding
  // stage ran and produced a value for the reported player.
  json j;
  if (hasPlayerValue(v)) j["CFR"] = v[playerIdx];
  if (hasPlayerValue(vPure)) j["CFRPure"] = vPure[playerIdx];
  if (hasPlayerValue(vLoaded)) j["Loaded"] = vLoaded[playerIdx];
  j["Search"] = searchResult.bestSoFar;
  std::cout << "json_str: " << j.dump() << std::endl;

  solver.manager().printStrategy();
  if (dumpJson) {
    std::cout << std::endl
              << "json_str: " << solver.manager().strategyJson() << std::endl;
  }

  /*
  std::cout << "Improving strategy with joint search: " << std::endl;
  std::cout << solver.enumPolicies(1);
  */

  auto stop = high_resolution_clock::now();
  std::cout << "Time spent: "
            << duration_cast<microseconds>(stop - start).count() / 1e6 << "s"
            << std::endl;

  if (options.verbose == tabular::VerboseLevel::VERBOSE) {
    std::cout << solver.printTree() << std::endl;
  }

  // Compare against the reference strategy when the game provides one and
  // the comparison was not disabled with --no_opt.
  if (strategy != nullptr && !result["no_opt"].as<bool>()) {
    auto strategies = [&](const std::string& key) {
      return strategy->getOptimalStrategy(key);
    };
    std::cout << "Optimal strategy: " << std::endl;
    solver.manager().setStrategies(strategies);
    std::cout << "Evaluating.. " << std::endl;
    solver.evaluate();
    const auto& optimalValues = solver.root()->u();
    for (int i = 1; i < static_cast<int>(optimalValues.size()); ++i) {
      std::cout << "Player " << i << " optimal value: " << optimalValues[i]
                << std::endl;
    }
  }
  return 0;
}