// AlgResult runSearch()
// in simple_game/search.h [1059:1257]

  // Iteratively improves the joint strategy for `playerIdx` by local search
  // over info sets drawn from `sampler`, for up to `numIteration` rounds.
  //
  // Each round: (optionally) perturb chance/policy, fully evaluate the tree,
  // search over a sampled batch of info sets for a better joint action
  // assignment, optionally cross-check against brute force, and commit the
  // best actions found back into `manager_`.
  //
  // Params:
  //   playerIdx    - index of the player whose utility is being maximized.
  //   numIteration - maximum number of search rounds.
  //   sampler      - yields batches of info sets; reset whenever the search
  //                  improves (and before/after the whole run). An empty
  //                  sample terminates the loop early.
  // Returns: AlgResult with `bestSoFar` (best unperturbed score observed,
  //   including the final evaluation) and `lastU` (utilities of the root
  //   after the final unperturbed evaluation).
  AlgResult runSearch(int playerIdx,
                      int numIteration,
                      InfoSetsSampler& sampler) {
    // The search assumes the root is a chance node.
    assert(root_->infoSet().isChance());
    // Best (or current) score from the previous round; used both for the
    // determinism sanity check and the improvement test below.
    float lastBest = 0.0f;

    AlgResult result;
    // Start at lowest so the first unperturbed score always registers.
    result.bestSoFar = std::numeric_limits<float>::lowest();

    sampler.reset();

    for (int k = 0; k < numIteration; ++k) {
      // Apply optional random perturbations before evaluation. If either is
      // active, the evaluated score is noisy and must not update bestSoFar.
      bool perturbed = false;
      if (options_.perturbChance > 0) {
        perturbed = true;
        manager_.perturbChance(options_.perturbChance);
      }

      if (options_.perturbPolicy > 0) {
        perturbed = true;
        manager_.perturbPolicy(options_.perturbPolicy);
      }

      // Full evaluation of the game tree under the current strategy.
      evaluate();
      const auto& u = root_->u();

      float baseScore = u[playerIdx];
      if (!perturbed)
        result.bestSoFar = std::max(result.bestSoFar, baseScore);

      // Sanity check: in the fully deterministic setting (no sampling, no
      // perturbation), re-evaluating the strategy chosen last round should
      // reproduce last round's best value exactly (within 1e-6).
      if (k > 0 && std::abs(baseScore - lastBest) >= 1e-6 &&
          options_.numSample == 0 && options_.numSampleTotal == 0 &&
          options_.perturbChance == 0 && options_.perturbPolicy == 0) {
        if (options_.verbose != SILENT) {
          std::cout << "Potential err! lastBest [" << lastBest << "]"
                    << " != baseScore [" << baseScore << "]" << std::endl;
        }
      }

      if (k == 0) {
        // Record the pre-optimization baseline.
        lastBest = baseScore;
        if (options_.verbose != SILENT) {
          std::cout << "Score before optimization: " << baseScore << std::endl;
        }
      }

      if (options_.numSample > 0 || options_.numSampleTotal > 0) {
        if (std::abs(baseScore - lastBest) >= 1e-6) {
          // sampled based approach may estimate score wrong.
          // Trust the full evaluation over the sampled estimate.
          lastBest = baseScore;
        }
        if (options_.verbose != SILENT) {
          std::cout << "[" << k << "]: full eval score: " << baseScore
                    << std::endl;
        }
      }

      // Which infoSets we want to use?
      InfoSets infoSets = sampler.sample();
      // Sampler exhausted: nothing left to search over.
      if (infoSets.empty())
        break;

      /*
      if (numSamples > 0) {
        if (options_.firstRandomInfoSetKey != "" && k == 0) {
          keys.clear();
          keys.push_back(options_.firstRandomInfoSetKey);
        } else {
          // Random pick one.
          std::random_shuffle(keys.begin(), keys.end());
          keys.erase(keys.begin() + 1, keys.end());
        }
      }
      We could also label each states and its descendents to be active, if we
      want to do sample-based approach.
      */
      /*
      for (const auto& key : keys) {
        auto samples = manager_.drawSamples(key, numSamples);
        auto res2 = manager_.enumPoliciesSamples(samples, playerIdx);
        resultSampling.combine(res2);
      }
      */
      Analysis analysis;
      Stats stats;
      // Time the search phase (wall clock, reported in seconds).
      auto start = std::chrono::high_resolution_clock::now();
      auto resultSampling =
          _search2({},
                   infoSets,
                   playerIdx,
                   options_.computeReach ? &analysis : nullptr,
                   &stats);
      // NOTE(review): presumably shifts candidate values by the pre-search
      // score so they are comparable to baseScore/lastBest — confirm in the
      // result type's addBias implementation.
      resultSampling.addBias(baseScore);
      auto stop = std::chrono::high_resolution_clock::now();
      float searchTime =
          std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
              .count() /
          1e6;

      if (options_.verbose == VERBOSE) {
        std::cout << "candidates from search: " << std::endl
                  << resultSampling.info(manager_) << std::endl;
      }

      auto best = resultSampling.getBest();

      // Improvement beyond a 1e-4 tolerance restarts the sampler so the
      // next rounds re-explore from the improved strategy.
      bool improved = false;
      if (best.value - lastBest > 1e-4) {
        improved = true;
        if (options_.showBetter) {
          manager_.printStrategy();
        }
        sampler.reset();
      }

      if (options_.verbose != SILENT) {
        std::cout << "[" << k << ":search]: time: " << searchTime
                  << " on states: " << stats.time_states;
        if (improved)
          std::cout << " result(*): ";
        else
          std::cout << " result: ";
        std::cout << best.info(manager_) << std::endl;
      }
      lastBest = best.value;

      // Optional ground-truth cross-check: brute-force the same info sets
      // and compare against the search result.
      if (options_.gtCompute) {
        Analysis analysisGt;

        // Intentionally shadows the outer start/stop timers; scoped to this
        // brute-force block.
        auto start = std::chrono::high_resolution_clock::now();
        auto resultBruteForce = _bruteforceSearchJointInfoSet(
            {},
            infoSets,
            playerIdx,
            options_.computeReach ? &analysisGt : nullptr);
        auto stop = std::chrono::high_resolution_clock::now();
        float bruteForceTime =
            std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
                .count() /
            1e6;

        auto bestBruteForce = resultBruteForce.getBest();

        if (options_.verbose != SILENT) {
          std::cout << "[" << k << ":brute ]: time: " << bruteForceTime
                    << " result: " << bestBruteForce.info(manager_)
                    << std::endl;

          // Search should agree with brute force within 1e-4; warn if not.
          // NOTE(review): this warning is only emitted when verbose != SILENT.
          if (std::abs(bestBruteForce.value - best.value) >= 1e-4) {
            std::cout << "Warning! search value [" << best.value
                      << "] != bruteForce value [" << bestBruteForce.value
                      << "]" << std::endl;
          }

          if (options_.verbose == VERBOSE) {
            std::cout << "Search terms: " << std::endl;
            std::cout << analysis.terms.info(manager_, false) << std::endl;
          }
        }

        if (options_.gtOverride) {
          // Replace the search answer with the brute-force one.
          if (options_.verbose != SILENT) {
            std::cout << "Overriding search with bruteForce!" << std::endl;
          }
          best = bestBruteForce;
          lastBest = best.value;
        }

        if (options_.verbose == VERBOSE) {
          std::cout << "candidates from bruteForce: " << std::endl
                    << resultBruteForce.info(manager_) << std::endl;
        }

        if (options_.computeReach) {
          // Compare the difference of the two reachability.
          analysisGt.compareReach(analysis);
        }
      }

      // Converged: in the deterministic, unperturbed, depth-unlimited
      // setting, the search found nothing better than the current strategy.
      // NOTE(review): exact float equality here — relies on best.value being
      // derived from baseScore via addBias with no rounding; confirm.
      if (options_.numSample == 0 && options_.numSampleTotal == 0 &&
          options_.perturbChance == 0 && best.value == baseScore &&
          options_.maxDepth == 0)
        break;

      // Change the policy based on best policy.
      // Loop over its involved infos and change their actions.
      for (const auto& infoAction : best.actions) {
        manager_[infoAction.first].setDeltaStrategy(infoAction.second);
      }
    }

    sampler.reset();

    // Clear any chance perturbation, then do a final clean evaluation so
    // result.lastU / bestSoFar reflect the unperturbed strategy.
    manager_.perturbChance(0);
    evaluate();
    result.lastU = root_->u();
    result.bestSoFar = std::max(result.bestSoFar, result.lastU[playerIdx]);
    return result;
  }