in simple_game/search.h [1059:1257]
// Iteratively improves playerIdx's strategy by local search over sampled
// information sets.
//
// Each iteration:
//   1. optionally perturbs chance probabilities and/or policies,
//   2. fully re-evaluates the tree (evaluate()) to get a base score,
//   3. samples a batch of infoSets from `sampler`,
//   4. runs _search2 over that batch (optionally cross-checked against a
//      brute-force joint search when options_.gtCompute is set),
//   5. commits the best joint action change found back into manager_.
// The loop stops early when the sampler is exhausted or when (in the exact,
// unperturbed, unlimited-depth configuration) the search cannot beat the
// current base score.
//
// playerIdx:    index of the player whose utility u[playerIdx] is maximized.
// numIteration: maximum number of improvement iterations.
// sampler:      source of candidate infoSets; reset here on entry, on every
//               strict improvement, and again before the final evaluation.
// Returns an AlgResult whose bestSoFar is the best *unperturbed* full-eval
// score observed (including the final post-loop evaluation) and whose lastU
// is the root utility vector after perturbations are cleared.
AlgResult runSearch(int playerIdx,
int numIteration,
InfoSetsSampler& sampler) {
// The root is expected to be a chance node (e.g. the initial deal).
assert(root_->infoSet().isChance());
float lastBest = 0.0f;
AlgResult result;
result.bestSoFar = std::numeric_limits<float>::lowest();
sampler.reset();
for (int k = 0; k < numIteration; ++k) {
// Apply optional perturbations. Scores computed under perturbation are not
// trusted as "best so far" (see the !perturbed guard below).
bool perturbed = false;
if (options_.perturbChance > 0) {
perturbed = true;
manager_.perturbChance(options_.perturbChance);
}
if (options_.perturbPolicy > 0) {
perturbed = true;
manager_.perturbPolicy(options_.perturbPolicy);
}
// Full re-evaluation of the tree under the current (possibly perturbed)
// strategy; baseScore is this player's root utility.
evaluate();
const auto& u = root_->u();
float baseScore = u[playerIdx];
if (!perturbed)
result.bestSoFar = std::max(result.bestSoFar, baseScore);
// Sanity check: in the exact, unperturbed configuration the full eval must
// reproduce the value predicted by the previous iteration's search.
if (k > 0 && std::abs(baseScore - lastBest) >= 1e-6 &&
options_.numSample == 0 && options_.numSampleTotal == 0 &&
options_.perturbChance == 0 && options_.perturbPolicy == 0) {
if (options_.verbose != SILENT) {
std::cout << "Potential err! lastBest [" << lastBest << "]"
<< " != baseScore [" << baseScore << "]" << std::endl;
}
}
if (k == 0) {
lastBest = baseScore;
if (options_.verbose != SILENT) {
std::cout << "Score before optimization: " << baseScore << std::endl;
}
}
// When sampling is enabled, the searched value may drift from the true
// score; re-anchor lastBest on the exact full evaluation.
if (options_.numSample > 0 || options_.numSampleTotal > 0) {
if (std::abs(baseScore - lastBest) >= 1e-6) {
// sampled based approach may estimate score wrong.
lastBest = baseScore;
}
if (options_.verbose != SILENT) {
std::cout << "[" << k << "]: full eval score: " << baseScore
<< std::endl;
}
}
// Which infoSets we want to use?
InfoSets infoSets = sampler.sample();
// Sampler exhausted: nothing left to try.
if (infoSets.empty())
break;
/*
if (numSamples > 0) {
if (options_.firstRandomInfoSetKey != "" && k == 0) {
keys.clear();
keys.push_back(options_.firstRandomInfoSetKey);
} else {
// Random pick one.
std::random_shuffle(keys.begin(), keys.end());
keys.erase(keys.begin() + 1, keys.end());
}
}
We could also label each states and its descendents to be active, if we
want to do sample-based approach.
*/
/*
for (const auto& key : keys) {
auto samples = manager_.drawSamples(key, numSamples);
auto res2 = manager_.enumPoliciesSamples(samples, playerIdx);
resultSampling.combine(res2);
}
*/
Analysis analysis;
Stats stats;
// Timed search over the sampled infoSets. NOTE(review): the addBias call
// below is inside the timed region, so searchTime includes it — confirm
// that is intended.
auto start = std::chrono::high_resolution_clock::now();
auto resultSampling =
_search2({},
infoSets,
playerIdx,
options_.computeReach ? &analysis : nullptr,
&stats);
// Search results are deltas relative to the current score; shift them so
// candidate values are absolute. (Presumed from the name — verify.)
resultSampling.addBias(baseScore);
auto stop = std::chrono::high_resolution_clock::now();
// Wall-clock seconds spent in the search.
float searchTime =
std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
.count() /
1e6;
if (options_.verbose == VERBOSE) {
std::cout << "candidates from search: " << std::endl
<< resultSampling.info(manager_) << std::endl;
}
auto best = resultSampling.getBest();
bool improved = false;
// Strict improvement (beyond tolerance) restarts the sampler so all
// infoSets become candidates again under the new strategy.
if (best.value - lastBest > 1e-4) {
improved = true;
if (options_.showBetter) {
manager_.printStrategy();
}
sampler.reset();
}
if (options_.verbose != SILENT) {
std::cout << "[" << k << ":search]: time: " << searchTime
<< " on states: " << stats.time_states;
if (improved)
std::cout << " result(*): ";
else
std::cout << " result: ";
std::cout << best.info(manager_) << std::endl;
}
lastBest = best.value;
// Optional ground-truth check: brute-force the same joint infoSet search
// and compare against the (faster) _search2 result.
if (options_.gtCompute) {
Analysis analysisGt;
auto start = std::chrono::high_resolution_clock::now();
auto resultBruteForce = _bruteforceSearchJointInfoSet(
{},
infoSets,
playerIdx,
options_.computeReach ? &analysisGt : nullptr);
auto stop = std::chrono::high_resolution_clock::now();
float bruteForceTime =
std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
.count() /
1e6;
auto bestBruteForce = resultBruteForce.getBest();
if (options_.verbose != SILENT) {
std::cout << "[" << k << ":brute ]: time: " << bruteForceTime
<< " result: " << bestBruteForce.info(manager_)
<< std::endl;
// Disagreement beyond tolerance suggests a bug in _search2.
if (std::abs(bestBruteForce.value - best.value) >= 1e-4) {
std::cout << "Warning! search value [" << best.value
<< "] != bruteForce value [" << bestBruteForce.value
<< "]" << std::endl;
}
if (options_.verbose == VERBOSE) {
std::cout << "Search terms: " << std::endl;
std::cout << analysis.terms.info(manager_, false) << std::endl;
}
}
// Optionally trust the brute-force result over the search result.
if (options_.gtOverride) {
if (options_.verbose != SILENT) {
std::cout << "Overriding search with bruteForce!" << std::endl;
}
best = bestBruteForce;
lastBest = best.value;
}
if (options_.verbose == VERBOSE) {
std::cout << "candidates from bruteForce: " << std::endl
<< resultBruteForce.info(manager_) << std::endl;
}
if (options_.computeReach) {
// Compare the difference of the two reachability.
analysisGt.compareReach(analysis);
}
}
// Converged: in the exact (no sampling, no perturbation, full-depth)
// setting, no candidate beat the current score. NOTE(review): exact float
// equality here relies on best.value being copied unchanged from the
// evaluation — confirm the tolerance-free compare is intentional.
if (options_.numSample == 0 && options_.numSampleTotal == 0 &&
options_.perturbChance == 0 && best.value == baseScore &&
options_.maxDepth == 0)
break;
// Change the policy based on best policy.
// Loop over its involved infos and change their actions.
for (const auto& infoAction : best.actions) {
manager_[infoAction.first].setDeltaStrategy(infoAction.second);
}
}
// Final pass: clear perturbations, re-evaluate exactly, and fold the final
// score into bestSoFar.
sampler.reset();
manager_.perturbChance(0);
evaluate();
result.lastU = root_->u();
result.bestSoFar = std::max(result.bestSoFar, result.lastU[playerIdx]);
return result;
}