int main()

in csrc/liars_dice/recursive_eval.cc [193:426]


int main(int argc, char* argv[]) {
  int num_dice = 1;
  int num_faces = 4;
  int subgame_iters = 1024;
  int mdp_depth = -1;
  int num_repeats = -1;
  std::string net_path;
  bool repeat_oracle_net = false;
  bool no_linear = false;
  bool root_only = false;
  bool print_regret = false;
  bool print_regret_summary = false;
  int eval_oracle_values_iters = -1;
  int num_threads = 10;
  SubgameSolvingParams base_params;
  std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
  {
    for (int i = 1; i < argc; i++) {
      std::string arg = argv[i];
      if (arg == "--num_dice") {
        assert(i + 1 < argc);
        num_dice = std::stoi(argv[++i]);
      } else if (arg == "--num_faces") {
        assert(i + 1 < argc);
        num_faces = std::stoi(argv[++i]);
      } else if (arg == "--subgame_iters") {
        assert(i + 1 < argc);
        subgame_iters = std::stoi(argv[++i]);
      } else if (arg == "--mdp_depth") {
        assert(i + 1 < argc);
        mdp_depth = std::stoi(argv[++i]);
      } else if (arg == "--num_threads") {
        assert(i + 1 < argc);
        num_threads = std::stoi(argv[++i]);
      } else if (arg == "--num_repeats") {
        assert(i + 1 < argc);
        num_repeats = std::stoi(argv[++i]);
      } else if (arg == "--root_only") {
        root_only = true;
      } else if (arg == "--repeat_oracle_net") {
        repeat_oracle_net = true;
      } else if (arg == "--net") {
        assert(i + 1 < argc);
        net_path = argv[++i];
      } else if (arg == "--print_regret") {
        print_regret = true;
      } else if (arg == "--print_regret_summary") {
        print_regret_summary = true;
      } else if (arg == "--no_linear") {
        no_linear = true;
      } else if (arg == "--optimistic") {
        base_params.optimistic = true;
      } else if (arg == "--eval_oracle_values_iters") {
        assert(i + 1 < argc);
        eval_oracle_values_iters = std::stoi(argv[++i]);
      } else if (arg == "--cfr") {
        base_params.use_cfr = true;
      } else if (arg == "--dcfr") {
        base_params.dcfr = true;
        base_params.dcfr_alpha = std::atof(argv[++i]);
        base_params.dcfr_beta = std::atof(argv[++i]);
        base_params.dcfr_gamma = std::atof(argv[++i]);
      } else {
        std::cerr << "Unknown flag: " << arg << "\n";
        return -1;
      }
    }
  }
  assert(num_dice != -1);
  assert(num_faces != -1);

  const Game game(num_dice, num_faces);
  std::cout << "num_dice=" << num_dice << " num_faces=" << num_faces << "\n";
  const auto full_tree = unroll_tree(game);
  std::cout << "Tree of depth " << get_depth(full_tree) << " has "
            << full_tree.size() << " nodes\n";

  std::cout << "##############################################\n";
  std::cout << "##### Solving the game for the full tree #####\n";
  std::cout << "##############################################\n";
  TreeStrategy full_strategy;
  base_params.num_iters = subgame_iters;
  base_params.linear_update = !no_linear && !base_params.dcfr;
  {
    SubgameSolvingParams params = base_params;
    params.max_depth = 100000;
    auto fp = build_solver(game, params);

    std::vector<TreeStrategy> strategy_list;

    for (int iter = 0; iter < subgame_iters; ++iter) {
      fp->step(iter % 2);
      if (iter % 2 == 0 && params.use_cfr) {
        strategy_list.push_back(fp->get_sampling_strategy());
      }
      if (((iter + 1) & iter) == 0 || iter + 1 == subgame_iters) {
        auto values = compute_exploitability2(game, fp->get_strategy());
        printf("Iter=%8d exploitabilities=(%.3e, %.3e) sum=%.3e\n", iter + 1,
               values[0], values[1], (values[0] + values[1]) / 2.);
      }
    }

    full_strategy = fp->get_strategy();
    auto explotabilities = compute_exploitability2(game, full_strategy);
    std::cout << "Full FP exploitability: "
              << (explotabilities[0] + explotabilities[1]) / 2. << " ("
              << explotabilities[0] << "," << explotabilities[1] << ")"
              << std::endl;
    // report_game_stats(game, fp->get_strategy());
    if (!strategy_list.empty()) {
      report_regrets(game, strategy_list, print_regret, print_regret_summary,
                     mdp_depth);
      std::cout << "\n";
    }
    print_strategy(game, fp->get_tree(), fp->get_strategy(),
                   "strategy.full.txt");
  }
  std::vector<std::pair<std::string, TreeStrategy>> all_strategies;
  all_strategies.emplace_back("full_tree", full_strategy);
  if (!net_path.empty()) {
    assert(mdp_depth > 0);
    std::shared_ptr<IValueNet> net =
        net_path == "zero"
            ? liars_dice::create_zero_net(game.num_hands(), false)
            : liars_dice::create_torchscript_net(net_path);

    std::cout << "##############################################\n";
    std::cout << "##### Recursive solving                      #\n";
    std::cout << "##############################################\n";
    if (num_repeats > 0) {
      torch::Tensor summed_stategy, summed_reach;
      TreeStrategy final_strategy;

      auto net_builder = [=]() {
        if (repeat_oracle_net) {
          SubgameSolvingParams oracle_net_params = base_params;
          oracle_net_params.max_depth = 100000;
          if (eval_oracle_values_iters > 0) {
            oracle_net_params.num_iters = eval_oracle_values_iters;
          }
          return liars_dice::create_oracle_value_predictor(game,
                                                           oracle_net_params);
        } else {
          return liars_dice::create_torchscript_net(net_path, "cpu");
        }
      };

      ParallerSampledStrategyComputor computer(game, net_builder, num_repeats,
                                               base_params, mdp_depth,
                                               num_threads, root_only);
      std::vector<TreeStrategy> strategy_list;
      for (int strategy_id = 0; strategy_id < num_repeats; ++strategy_id) {
        {
          auto [sampled_strategy_tensor, node_reach_tensor] =
              computer.get(strategy_id);

          if (strategy_id == 0) {
            summed_stategy = sampled_strategy_tensor * node_reach_tensor;
            summed_reach = node_reach_tensor;
          } else {
            summed_stategy += sampled_strategy_tensor * node_reach_tensor;
            summed_reach += node_reach_tensor;
          }
          if (base_params.use_cfr) {
            strategy_list.push_back(
                tensor_to_tree_strategy(sampled_strategy_tensor));
          }
        }

        final_strategy =
            tensor_to_tree_strategy(summed_stategy / (summed_reach + 1e-6));
        if (((strategy_id + 1) & strategy_id) == 0 ||
            strategy_id + 1 == num_repeats) {
          std::cout << std::setw(5) << strategy_id + 1 << ": ";
          auto explotabilities = compute_exploitability2(game, final_strategy);
          auto evs = compute_ev2(game, full_strategy, final_strategy);
          std::cout << (explotabilities[0] + explotabilities[1]) / 2. << " ("
                    << explotabilities[0] << "," << explotabilities[1] << ")"
                    << "\tEV of full: ";
          std::cout << (evs[0] + evs[1]) / 2 << " (" << evs[0] << "," << evs[1]
                    << ")";
          if (!strategy_list.empty()) {
            report_regrets(game, strategy_list, print_regret,
                           print_regret_summary, mdp_depth);
          }
          std::cout << std::endl;
          const auto name = (repeat_oracle_net ? "repeated oracle toleaf "
                                               : "repeated toleaf ") +
                            std::to_string(strategy_id + 1);
          all_strategies.emplace_back(name, final_strategy);
          print_strategy(game, full_tree, final_strategy,
                         "strategy.repeated.txt");
        }
      }
    }
  }
  // Reporting in human-readable format.
  std::vector<std::pair<std::string, std::string>> result;
  std::vector<std::pair<std::string, std::string>> result_ev;
  result.emplace_back("net", net_path);
  result_ev.emplace_back("net", net_path);
  for (auto [name, mdp_strategy] : all_strategies) {
    std::cout << " " << name << " ";
    assert(mdp_strategy.size() == full_tree.size());
    auto explotabilities = compute_exploitability2(game, mdp_strategy);
    auto evs = compute_ev2(game, full_strategy, mdp_strategy);
    std::cout << (explotabilities[0] + explotabilities[1]) / 2. << " ("
              << explotabilities[0] << "," << explotabilities[1] << ")"
              << "\n\tEV of full: ";
    std::cout << (evs[0] + evs[1]) / 2. << " (" << evs[0] << "," << evs[1]
              << ")";
    std::cout << std::endl;
    result.emplace_back(
        name, std::to_string((explotabilities[0] + explotabilities[1]) / 2.));
    result_ev.emplace_back(name, std::to_string((evs[0] + evs[1]) / 2.));
  }
  // Reporting as JSON.
  std::vector<
      std::tuple<std::string, std::vector<std::pair<std::string, std::string>>>>
      all_results = {{"XXX", result}, {"YYY", result_ev}};
  for (auto [tag, dict] : all_results) {
    std::cout << tag << " {";
    bool first = true;
    for (auto [k, v] : dict) {
      if (first) {
        first = false;
      } else {
        std::cout << ", ";
      }
      std::cout << '"' << k << "\":\"" << v << "\"";
    }
    std::cout << "}" << std::endl;
  }
}