Move SearchBot::doSearch_()

in csrc/SearchBot.cc [705:844]


/**
 * Runs the multi-threaded rollout search for this player and returns the move
 * to play.
 *
 * Every legal move is scored with oneSearchIter_ over batches of rollouts.
 * Iterations are grouped into runs of num_moves consecutive indices. All moves
 * in one group are simulated with an RNG seeded from the same seed. Each move's
 * score is passed to accumScore together with the blueprint move's score from
 * the same group.
 *
 * When UCB is enabled, thread 0 does extra work after every full batch except
 * the last. It accumulates the batch and marks as pruned any move that
 * canPruneMove() allows. Pruned moves are not simulated any further.
 *
 * @param who        player index forwarded to oneSearchIter_.
 * @param bp_move    blueprint move. It must be one of the legal moves, because
 *                   its per-group score is looked up by index in `scores`. It
 *                   receives a SEARCH_THRESH bias when the final move is chosen.
 * @param frame_move if one of the legal moves gets pruned and equals this move,
 *                   the search bails out and returns Move().
 * @param me_bot     bot forwarded to oneSearchIter_.
 * @param handDist   forwarded to oneSearchIter_ (may hold stale probabilities;
 *                   see the note below).
 * @param cdf        forwarded to oneSearchIter_.
 * @param stats      per-move statistics. The entry for every legal move is
 *                   reset on entry, then filled in by the search.
 * @param gen        RNG used only to draw the per-group seeds.
 * @param server     game state; provides the legal moves.
 * @param verbose    log progress and the result to stderr.
 * @param win_stats  optional second statistics map. It is reset for the legal
 *                   moves and passed to accumScore alongside `stats`.
 * @return the unpruned entry of `stats` with the highest mean + bias, or
 *         Move() if the search bailed because `frame_move` was pruned. Entries
 *         already in `stats` that are not current legal moves are also
 *         considered.
 */
Move SearchBot::doSearch_(
    int who,
    Move bp_move,
    Move frame_move,
    Bot *me_bot,
    const HandDist &handDist,
    const HandDistCDF &cdf,
    SearchStats &stats,
    std::mt19937 &gen,
    const Server &server,
    bool verbose,
    SearchStats *win_stats) const {

  // n.b. the probabilities in handDist may not be right, because it's too
  // slow to update them for public -> private conversion. The probabilities in
  // cdf are considered the ground truth for the purposes of search

  std::vector<Move> moves = enumerateLegalMoves(server);
  int num_moves = moves.size();

  for (auto &move : moves) {
    stats[move] = UCBStats();
    if (win_stats) (*win_stats)[move] = UCBStats();
  }
  stats[bp_move].bias = SEARCH_THRESH;
  std::atomic<int> loop_count(0);
  if (verbose) {
    std::cerr << now() << "search player " << server.whoAmI() << " start" << std::endl;
  }

  bool frame_bail = false;
  int prune_count = 0;

  // Index of the blueprint move in `moves`. The baseline lookup
  // scores[group_start + bp_mi] below is only valid when bp_move is legal.
  int bp_mi = -1;
  for (int mi = 0; mi < num_moves; mi++) {
    if (moves.at(mi) == bp_move) {
      bp_mi = mi;
    }
  }
  assert(bp_mi >= 0);

  // Make the thread count a multiple of num_moves. Then every batch of
  // temp_num_threads iterations covers whole groups, and each move's blueprint
  // baseline lies in the same batch. We use this to facilitate baseline usage.
  // Shouldn't make a big difference.
  int temp_num_threads = NUM_THREADS - (NUM_THREADS % num_moves);
  // NOTE(review): this is 0 when NUM_THREADS < num_moves. Under NDEBUG the
  // modulo below is then undefined. Confirm NUM_THREADS always exceeds the
  // largest possible number of legal moves.
  assert(temp_num_threads >= num_moves);

  // Round down so that every thread runs the same number of iterations. Each
  // thread then reaches the barrier the same number of times.
  int temp_search_n = SEARCH_N - (SEARCH_N % temp_num_threads);

  std::vector<boost::fibers::future<void>> futures;
  Barrier barrier(temp_num_threads);

  // One seed per group of num_moves consecutive iterations. Every move in a
  // group is simulated with an RNG seeded identically.
  std::uniform_int_distribution<int> uid1(0, 1 << 30);
  std::vector<int> seeds(SEARCH_N / num_moves + 1);
  for (int &seed : seeds) seed = uid1(gen);

  // Each entry is one of:
  //   -2  never written (iteration not run)
  //   -1  the move was already pruned
  //   otherwise, oneSearchIter_'s result.
  std::vector<int> scores(SEARCH_N, -2);
  // Number of leading entries of `scores` already passed to accumScore.
  int accumed = 0;
  for (int t = 0; t < temp_num_threads; t++) {
    futures.push_back(getThreadPool().enqueue([&, t](){
      // Thread t runs iterations t, t + T, t + 2T, ... where T = temp_num_threads.
      for (int j = t; j < temp_search_n; j += temp_num_threads) {
        // frame_bail and prune_count are written only by thread 0, between the
        // two barrier waits below. Assuming Barrier synchronizes memory, every
        // thread sees the same values here, so all threads leave together.
        if (frame_bail || prune_count >= num_moves - 1) {
          break;
        }

        // multi-threaded stuff
        int mi = j % num_moves;
        int g = j / num_moves;
        assert(g < static_cast<int>(seeds.size()));
        if (seeds[g] == 0) {
          std::cerr << "WARNING: seed is 0!\n";
        }
        std::mt19937 my_gen(seeds[g]);

        auto sampled_move = moves.at(mi);
        // All workers read `stats` concurrently here. at() is treated as const
        // for data-race purposes; operator[] is not. Every legal move was
        // inserted above, so at() never throws here.
        if (!stats.at(sampled_move).pruned) {
          loop_count++;
          scores[j] = oneSearchIter_(me_bot, who, sampled_move, cdf, server, handDist, my_gen);
        } else {
          scores[j] = -1; // sentinel
        }

        // single-threaded stuff: runs after each full batch except the last.
        // The last batch is accumulated after all threads have joined.
        if (UCB && j + temp_num_threads < temp_search_n) {
          barrier.wait();

          // Thread 0's j is always a multiple of T, so [j, j + T) is exactly
          // the batch that all threads just finished.
          if (t == 0) {
            for (int k = j; k < j + temp_num_threads; k++) {
              int bp_score = scores[k - (k % num_moves) + bp_mi];
              accumScore(scores[k], bp_score, moves[k % num_moves], stats, win_stats);
            }

            for (int pi = 0; pi < num_moves; pi++) {
              if (!stats[moves[pi]].pruned && canPruneMove(stats, moves[pi], bp_move)) {
                stats[moves[pi]].pruned = true;
                prune_count++;
                if (moves[pi] == frame_move) {
                  frame_bail = true;
                }
              }
            }
            accumed += temp_num_threads;
          } // if (t == 0)

          barrier.wait();
        }
      }
    }));
  }
  for (auto &f : futures) {
    f.get();
  }
  if (frame_bail) { //Then all that matters is we didn't choose the observed action
    return Move();
  }
  if (prune_count < num_moves - 1) { // accumulate the stragglers
    for (int k = accumed; k < temp_search_n; k++) {
      int bp_score = scores[k - (k % num_moves) + bp_mi];
      accumScore(scores[k], bp_score, moves[k % num_moves], stats, win_stats);
    }
  }

  total_iters_ += loop_count;
  Move best_move;
  double best_score = -1;
  for (auto &kv : stats) {
    if (kv.second.pruned) continue;
    if (kv.second.mean + kv.second.bias > best_score) {
      best_move = kv.first;
      best_score = kv.second.mean + kv.second.bias;
    }
  }
  if (verbose) {
    std::cerr << now() << "Ran " << loop_count << " search iters over " << num_moves << " moves. ( " << server.handsAsString()
              << " ) , p " << server.whoAmI() << " --> " << best_move.toString() << " (" << stats[best_move].mean << ") [bp " << bp_move.toString() << " (" << stats[bp_move].mean << ") ]" << std::endl;
  }
  return best_move;
}