void R2D2Actor::reset()

in rlcc/r2d2_actor.cc [127:173]


void R2D2Actor::reset(const HanabiEnv& env) {
  hidden_ = getH0(batchsize_, runner_);
  if (beliefRunner_ != nullptr) {
    beliefHidden_ = getH0(batchsize_, beliefRunner_);
  }

  if (r2d2Buffer_ != nullptr) {
    r2d2Buffer_->init(hidden_);
  }

  const auto& game = env.getHleGame();
  int fixColorPlayer = -1;
  if (vdn_ && shuffleColor_) {
    fixColorPlayer = rng_() % game.NumPlayers();
  }

  for (int i = 0; i < batchsize_; ++i) {
    assert(playerEps_.size() > 0 && epsList_.size() > 0);
    playerEps_[i] = epsList_[rng_() % epsList_.size()];
    if (tempList_.size() > 0) {
      assert(playerTemp_.size() > 0);
      playerTemp_[i] = tempList_[rng_() % tempList_.size()];
    }

    // other-play
    if (shuffleColor_) {
      auto& colorPermute = colorPermutes_[i];
      auto& invColorPermute = invColorPermutes_[i];
      colorPermute.clear();
      invColorPermute.clear();
      for (int i = 0; i < game.NumColors(); ++i) {
        colorPermute.push_back(i);
        invColorPermute.push_back(i);
      }
      if (i == fixColorPlayer) {
        continue;
      }
      std::shuffle(colorPermute.begin(), colorPermute.end(), rng_);
      std::sort(invColorPermute.begin(), invColorPermute.end(), [&](int i, int j) {
        return colorPermute[i] < colorPermute[j];
      });
      for (int i = 0; i < (int)colorPermute.size(); ++i) {
        assert(invColorPermute[colorPermute[i]] == i);
      }
    }
  }
}