in csrc/SearchBot.cc [705:844]
/**
 * Searches over every legal move of the current position and returns the one
 * with the highest estimated value (mean + bias).
 *
 * Iteration j evaluates move `moves[j % num_moves]` with the seed of group
 * `j / num_moves`, so every move in a group shares a seed and its score can be
 * paired with the blueprint move's score from the same seed (see accumScore()).
 * When UCB is enabled, iterations run in rounds of temp_num_threads. Between
 * rounds, thread 0 folds the round's scores into `stats` and marks moves that
 * canPruneMove() rejects as pruned. Pruned moves are skipped for the rest of
 * the search.
 *
 * @param who        player index, passed through to oneSearchIter_().
 * @param bp_move    the blueprint move. Its stats entry gets bias SEARCH_THRESH,
 *                   which is added to its mean when picking the best move.
 *                   Must be one of the legal moves.
 * @param frame_move if it is one of the legal moves and gets pruned during a
 *                   UCB round, the search stops early and returns Move().
 * @param me_bot     passed through to oneSearchIter_().
 * @param handDist   passed through to oneSearchIter_(). Its probabilities may be
 *                   stale (see the note in the body); cdf is authoritative.
 * @param cdf        passed through to oneSearchIter_().
 * @param stats      per-move statistics. The entry of every legal move is reset
 *                   to UCBStats() before searching.
 * @param gen        used only to draw the per-group seeds.
 * @param server     the current game state.
 * @param verbose    when true, logs start and result to std::cerr.
 * @param win_stats  optional. When non-null, it is reset for every legal move and
 *                   updated via accumScore() alongside `stats`.
 * @return the unpruned entry of `stats` with the largest mean + bias, or
 *         Move() if the frame_move bail-out triggered.
 */
Move SearchBot::doSearch_(
    int who,
    Move bp_move,
    Move frame_move,
    Bot *me_bot,
    const HandDist &handDist,
    const HandDistCDF &cdf,
    SearchStats &stats,
    std::mt19937 &gen,
    const Server &server,
    bool verbose,
    SearchStats *win_stats) const {
  // n.b. the probabilities in handDist may not be right, because it's too
  // slow to update them for public -> private conversion. The probabilities in
  // cdf are considered the ground truth for the purposes of search
  std::vector<Move> moves = enumerateLegalMoves(server);
  const int num_moves = static_cast<int>(moves.size());
  for (auto &move : moves) {
    stats[move] = UCBStats();
    if (win_stats) (*win_stats)[move] = UCBStats();
  }
  stats[bp_move].bias = SEARCH_THRESH;
  std::atomic<int> loop_count(0);
  if (verbose) {
    std::cerr << now() << "search player " << server.whoAmI() << " start" << std::endl;
  }

  // Only thread 0 writes these, between the two barrier.wait() calls below.
  // Workers read them at the top of each round, after the second wait.
  // Assuming Barrier::wait() synchronizes like a mutex, all workers see the same
  // values and therefore leave the loop in the same round. A worker that leaves
  // early would deadlock the barrier.
  bool frame_bail = false;
  int prune_count = 0;

  int bp_mi = -1;
  for (int mi = 0; mi < num_moves; mi++) {
    if (moves.at(mi) == bp_move) {
      bp_mi = mi;
    }
  }
  // The score pairing below indexes scores[] relative to bp_mi, so the
  // blueprint move has to be legal.
  assert(bp_mi >= 0);

  // We use this to facilitate baseline usage. temp_num_threads is the largest
  // multiple of num_moves not exceeding NUM_THREADS. Each round of
  // temp_num_threads iterations therefore covers whole groups, and the
  // blueprint score of a group is always produced in the same round as the
  // rest of the group.
  // n.b. if num_moves > NUM_THREADS this is 0. The assert catches that in
  // debug builds; in release builds the modulo below would divide by zero.
  const int temp_num_threads = NUM_THREADS - (NUM_THREADS % num_moves);
  assert(temp_num_threads >= num_moves);
  // Round down so every thread runs the same number of rounds, and so hits
  // barrier.wait() the same number of times.
  const int temp_search_n = SEARCH_N - (SEARCH_N % temp_num_threads);

  std::vector<boost::fibers::future<void>> futures;
  Barrier barrier(temp_num_threads);

  // One seed per group of num_moves consecutive iterations, drawn up front
  // from gen so the workers never touch the caller's RNG.
  std::uniform_int_distribution<int> uid1(0, 1 << 30);
  std::vector<int> seeds(SEARCH_N / num_moves + 1);
  for (auto &seed : seeds) seed = uid1(gen);

  // -2 is the initial placeholder. Each iteration overwrites its slot with a
  // search score, or with -1 when the move was already pruned.
  std::vector<int> scores(SEARCH_N, -2);
  int accumed = 0;  // number of leading scores[] entries already accumulated by thread 0

  for (int t = 0; t < temp_num_threads; t++) {
    futures.push_back(getThreadPool().enqueue([&, t]() {
      for (int j = t; j < temp_search_n; j += temp_num_threads) {
        if (frame_bail || prune_count >= num_moves - 1) {
          break;
        }
        // multi-threaded stuff
        const int mi = j % num_moves;
        const int g = j / num_moves;
        assert(g < static_cast<int>(seeds.size()));  // check before indexing
        if (seeds[g] == 0) {
          std::cerr << "WARNING: seed is 0!\n";
        }
        std::mt19937 my_gen(seeds[g]);
        auto sampled_move = moves.at(mi);
        // Workers run this concurrently. On (unordered) associative containers
        // the standard does not treat operator[] as const for data-race
        // purposes, but it does treat at() as const. Every legal move already
        // has an entry, so at() cannot throw here.
        if (!stats.at(sampled_move).pruned) {
          loop_count++;
          scores[j] = oneSearchIter_(me_bot, who, sampled_move, cdf, server, handDist, my_gen);
        } else {
          scores[j] = -1; // sentinel
        }
        // single-threaded stuff. After every round except the last, thread 0
        // accumulates that round's scores and updates pruning while the other
        // workers wait at the second barrier.
        if (UCB && j + temp_num_threads < temp_search_n) {
          barrier.wait();
          if (t == 0) {
            for (int k = j; k < j + temp_num_threads; k++) {
              const int bp_score = scores[k - (k % num_moves) + bp_mi];
              accumScore(scores[k], bp_score, moves[k % num_moves], stats, win_stats);
            }
            for (int pmi = 0; pmi < num_moves; pmi++) {
              if (!stats[moves[pmi]].pruned && canPruneMove(stats, moves[pmi], bp_move)) {
                stats[moves[pmi]].pruned = true;
                prune_count++;
                if (moves[pmi] == frame_move) {
                  frame_bail = true;
                }
              }
            }
            accumed += temp_num_threads;
          } // if (t == 0)
          barrier.wait();
        }
      }
    }));
  }
  for (auto &f : futures) {
    f.get();
  }
  if (frame_bail) { // Then all that matters is we didn't choose the observed action
    return Move();
  }
  // Accumulate the stragglers. That is everything when UCB is off; otherwise
  // it is the final round, which skips the in-loop accumulation.
  if (prune_count < num_moves - 1) {
    for (int k = accumed; k < temp_search_n; k++) {
      const int bp_score = scores[k - (k % num_moves) + bp_mi];
      accumScore(scores[k], bp_score, moves[k % num_moves], stats, win_stats);
    }
  }
  total_iters_ += loop_count.load();

  // NOTE(review): this scans all of `stats`, not just `moves`. Entries left in
  // the caller's map for moves that are not legal now would also compete.
  // Confirm callers pass a fresh map.
  Move best_move;
  double best_score = -1;
  for (auto &kv : stats) {
    if (kv.second.pruned) continue;
    if (kv.second.mean + kv.second.bias > best_score) {
      best_move = kv.first;
      best_score = kv.second.mean + kv.second.bias;
    }
  }
  if (verbose) {
    std::cerr << now() << "Ran " << loop_count << " search iters over " << num_moves << " moves. ( " << server.handsAsString()
              << " ) , p " << server.whoAmI() << " --> " << best_move.toString() << " (" << stats[best_move].mean << ") [bp " << bp_move.toString() << " (" << stats[bp_move].mean << ") ]" << std::endl;
  }
  return best_move;
}