int main()

in csrc/liars_dice/gen_benchmark.cc [54:154]


int main(int argc, char* argv[]) {
  int num_dice = 1;
  int num_faces = 4;
  int fp_iters = 1024;
  int mdp_depth = 2;
  int num_threads = 10;
  int per_gpu = 1;
  int num_cycles = 6;
  std::string device = "cuda:1";
  std::string net_path;
  {
    for (int i = 1; i < argc; i++) {
      std::string arg = argv[i];
      if (arg == "--num_dice") {
        assert(i + 1 < argc);
        num_dice = std::stoi(argv[++i]);
      } else if (arg == "--num_faces") {
        assert(i + 1 < argc);
        num_faces = std::stoi(argv[++i]);
      } else if (arg == "--fp_iters") {
        assert(i + 1 < argc);
        fp_iters = std::stoi(argv[++i]);
      } else if (arg == "--mdp_depth") {
        assert(i + 1 < argc);
        mdp_depth = std::stoi(argv[++i]);
      } else if (arg == "--num_threads") {
        assert(i + 1 < argc);
        num_threads = std::stoi(argv[++i]);
      } else if (arg == "--per_gpu") {
        assert(i + 1 < argc);
        per_gpu = std::stoi(argv[++i]);
      } else if (arg == "--num_cycles") {
        assert(i + 1 < argc);
        num_cycles = std::stoi(argv[++i]);
      } else if (arg == "--device") {
        assert(i + 1 < argc);
        device = argv[++i];
      } else if (arg == "--net") {
        assert(i + 1 < argc);
        net_path = argv[++i];
      } else {
        std::cerr << "Unknown flag: " << arg << "\n";
        return -1;
      }
    }
  }
  assert(num_dice != -1);
  assert(num_faces != -1);
  assert(mdp_depth != -1);

  const Game game(num_dice, num_faces);
  assert(mdp_depth > 0);
  assert(!net_path.empty());
  std::cout << "num_dice=" << num_dice << " num_faces=" << num_faces << "\n";
  {
    const auto full_tree = unroll_tree(game);
    std::cout << "Tree of depth " << get_depth(full_tree) << " has "
              << full_tree.size() << " nodes\n";
  }

  std::vector<TorchJitModel> models;
  for (int i = 0; i < per_gpu; ++i) {
    auto module = torch::jit::load(net_path);
    module.eval();
    module.to(device);
    models.push_back(module);
  }
  std::vector<TorchJitModel*> model_ptrs;
  for (int i = 0; i < per_gpu; ++i) {
    model_ptrs.push_back(&models[i]);
  }
  auto locker = std::make_shared<ModelLocker>(model_ptrs, device);
  auto replay = std::make_shared<ValuePrioritizedReplay>(1 << 20, 1000, 1.0,
                                                         0.4, 3, false, false);
  auto context = std::make_shared<Context>();

  RecursiveSolvingParams cfg;
  cfg.num_dice = num_dice;
  cfg.num_faces = num_faces;
  cfg.subgame_params.num_iters = fp_iters;
  cfg.subgame_params.linear_update = true;
  cfg.subgame_params.optimistic = false;
  cfg.subgame_params.max_depth = mdp_depth;
  for (int i = 0; i < num_threads; ++i) {
    const int seed = i;
    auto connector = std::make_shared<CVNetBufferConnector>(locker, replay);
    std::shared_ptr<ThreadLoop> loop =
        std::make_shared<DataThreadLoop>(std::move(connector), cfg, seed);
    context->pushThreadLoop(loop);
  }
  std::cout << "Starting the context" << std::endl;
  context->start();
  Timer t;
  for (int i = 0; i < num_cycles; ++i) {
    std::this_thread::sleep_for(std::chrono::seconds(10));
    double secs = t.tick();
    auto added = replay->numAdd();
    std::cout << "time=" << secs << " "
              << "items=" << added << " per_second=" << added / secs << "\n";
  }
}