in rlcc/clone_data_generator.cc [30:103]
// Replays recorded games and feeds them to the replay buffer.
//
// Games are visited in a freshly shuffled order. When infLoop_ is false the
// data set is traversed once: epoch_ is bumped from 0 to 1 on the first fill
// and the loop exits the next time the index list runs dry (so it exits
// immediately if epoch_ is already non-zero on entry). When infLoop_ is true
// the index list is reshuffled and refilled until terminated() becomes true.
//
// For each recorded move, every player's observation is pushed to that
// player's r2d2 buffer together with an action: the acting player records the
// uid of the recorded move (colour-permuted for colour hints when
// shuffleColor_ is set), every other player records env.noOpUid(). After the
// move is applied, the step reward and terminal flag are pushed for all
// players. Once a game is exhausted, one transition per player is popped and
// added to replayBuffer_.
void DataGenLoop::mainLoop() {
  assert(gameDatas_.size() > 0);
  // The assert disappears under NDEBUG; without this guard an empty data set
  // would leave idxsLeft empty after the refill and back()/pop_back() below
  // would be undefined behaviour.
  if (gameDatas_.empty()) {
    return;
  }

  std::vector<size_t> idxsLeft;
  while (!terminated()) {
    if (idxsLeft.empty()) {
      if (!infLoop_) {
        if (epoch_ != 0) {
          break;  // single pass already completed
        }
        ++epoch_;
      }
      // Refill with 0..N-1 and randomise the visiting order.
      idxsLeft.resize(gameDatas_.size());
      std::iota(idxsLeft.begin(), idxsLeft.end(), 0);
      std::shuffle(idxsLeft.begin(), idxsLeft.end(), rng_);
    }

    const size_t idx = idxsLeft.back();
    idxsLeft.pop_back();

    // Read-only access: bind by reference instead of copying the whole
    // recorded game (deck and move list) on every iteration.
    const auto& gameData = gameDatas_[idx];
    const size_t numMoves = gameData.moves_.size();

    HanabiEnv env(gameParams_, maxLen_, false);
    env.resetWithDeck(gameData.deck_);
    auto& state = env.getHleState();
    if (shuffleColor_) {
      shuffleColor(env.getHleGame());
    }

    for (size_t midx = 0; midx < numMoves; ++midx) {
      auto move = gameData.moves_[midx];
      const int curPlayer = env.getCurrentPlayer();

      for (int i = 0; i < numPlayer_; ++i) {
        auto obs = observe(
            state,
            i,
            shuffleColor_,
            colorPermutes_[i],
            invColorPermutes_[i],
            false,     // hideAction
            trinary_,  // trinary for aux task
            false);    // sad
        r2d2Buffers_[i].pushObs(obs);

        int action = -1;
        if (i != curPlayer) {
          // Players who are not on turn record a no-op.
          action = env.noOpUid();
        } else if (
            shuffleColor_ && move.MoveType() == hle::HanabiMove::kRevealColor) {
          // The observation was built with this player's colour permutation,
          // so the colour hint has to be mapped through the same permutation.
          auto shuffledMove = move;
          shuffledMove.SetColor(colorPermutes_[i][move.Color()]);
          action = env.getHleGame().GetMoveUid(shuffledMove);
        } else {
          action = env.getHleGame().GetMoveUid(move);
        }
        r2d2Buffers_[i].pushAction({{"a", torch::tensor(action)}});
      }

      env.step(move);
      const float reward = env.stepReward();
      // The last recorded move always closes the episode, even if the
      // environment itself does not report termination.
      const bool isLastMove = (midx + 1 == numMoves);
      const float terminal = (env.terminated() || isLastMove) ? 1.0f : 0.0f;

      for (int i = 0; i < numPlayer_; ++i) {
        r2d2Buffers_[i].pushReward(reward);
        r2d2Buffers_[i].pushTerminal(terminal);
      }
    }

    for (int i = 0; i < numPlayer_; ++i) {
      // NOTE(review): second argument is presumably the initial priority of
      // the transition -- confirm against the replay buffer's add() signature.
      replayBuffer_->add(r2d2Buffers_[i].popTransition(), 1.0);
    }
  }  // while (!terminated())
}