in rlcc/r2d2_actor.cc [127:173]
// Re-initializes this actor's per-game state for the game held by `env`.
//
// Steps, in order:
//   1. hidden_ is replaced by getH0(batchsize_, runner_); when beliefRunner_ is
//      set, beliefHidden_ is refreshed the same way.
//   2. When r2d2Buffer_ is set, it is (re)initialized with the new hidden_.
//   3. For each batch entry, playerEps_[i] is drawn from epsList_ and, when
//      tempList_ is non-empty, playerTemp_[i] is drawn from tempList_.
//   4. When shuffleColor_ is set, each batch entry gets a freshly shuffled
//      color permutation and the matching inverse. When vdn_ is also set, one
//      randomly chosen index (fixColorPlayer) keeps the identity permutation.
//
// The sequence of rng_ draws is unchanged from the previous implementation
// (fixColorPlayer, then per entry: eps, temperature, shuffle), so runs with
// the same rng_ state produce the same values.
void R2D2Actor::reset(const HanabiEnv& env) {
  hidden_ = getH0(batchsize_, runner_);
  if (beliefRunner_ != nullptr) {
    beliefHidden_ = getH0(batchsize_, beliefRunner_);
  }
  if (r2d2Buffer_ != nullptr) {
    r2d2Buffer_->init(hidden_);
  }

  const auto& game = env.getHleGame();

  // -1 never matches a batch index, i.e. "no entry is exempt from shuffling".
  int fixColorPlayer = -1;
  if (vdn_ && shuffleColor_) {
    fixColorPlayer = rng_() % game.NumPlayers();
  }

  for (int batchIdx = 0; batchIdx < batchsize_; ++batchIdx) {
    assert(playerEps_.size() > 0 && epsList_.size() > 0);
    playerEps_[batchIdx] = epsList_[rng_() % epsList_.size()];

    if (tempList_.size() > 0) {
      assert(playerTemp_.size() > 0);
      playerTemp_[batchIdx] = tempList_[rng_() % tempList_.size()];
    }

    // other-play
    if (!shuffleColor_) {
      continue;
    }

    auto& colorPermute = colorPermutes_[batchIdx];
    auto& invColorPermute = invColorPermutes_[batchIdx];

    // Start both mappings from the identity permutation 0..numColors-1.
    const int numColors = game.NumColors();
    colorPermute.clear();
    invColorPermute.clear();
    for (int color = 0; color < numColors; ++color) {
      colorPermute.push_back(color);
      invColorPermute.push_back(color);
    }

    // The exempt entry keeps the identity mapping and consumes no randomness.
    if (batchIdx == fixColorPlayer) {
      continue;
    }

    std::shuffle(colorPermute.begin(), colorPermute.end(), rng_);

    // Build the inverse directly: if color `c` maps to colorPermute[c], then
    // colorPermute[c] must map back to `c`.
    const int permSize = static_cast<int>(colorPermute.size());
    for (int color = 0; color < permSize; ++color) {
      invColorPermute[colorPermute[color]] = color;
    }

    // Debug-only sanity check that the two mappings are mutual inverses.
    for (int color = 0; color < permSize; ++color) {
      assert(invColorPermute[colorPermute[color]] == color);
    }
  }
}