in atari/atari_env.h [113:155]
std::tuple<rela::TensorDict, float, bool> step(
    const rela::TensorDict& action) final {
  // Bundles a feature tensor with the exploration epsilon and the
  // legal-action mask into the dict returned to the caller.
  auto packInput = [this](const torch::Tensor& feature) {
    return rela::TensorDict{
        {"s", feature},
        {"eps", exploreEps_},
        {"legal_move", legalActionMask_}};
  };

  // Advance the emulator with the action stored under key "a".
  const torch::Tensor& actionTensor = action.at("a");
  const float rawReward = aleStep(actionTensor.item<int>());

  // Record the raw reward and the current life count on the state.
  state_->addReward(rawReward);
  state_->setLives(ale_->lives());

  // The episode ends on ALE game over or once the frame budget is exceeded.
  const bool episodeOver =
      ale_->game_over() || numSteps_ * frameSkip_ > maxNumFrame_;
  if (episodeOver) {
    state_->setTerminal();
  }

  // Reward handed back to the learner is the clipped one; the terminal signal
  // optionally also fires on a life loss.
  const float clippedReward = clipRewards(rawReward);
  const bool terminalSignal =
      state_->terminal() || (terminalSignalOnLifeLoss_ && state_->lostLife());

  if (state_->terminal()) {
    // The feature is irrelevant after the episode ends, but a well-formed
    // input dict must still be returned. The screen buffer is deliberately
    // not refreshed here.
    return std::make_tuple(
        packInput(state_->computeFeature()), clippedReward, true);
  }

  // A lost life that did not end the game requires pressing start again.
  if (state_->lostLife()) {
    pressStartKey();
  }

  // Pull the latest screen into the state's buffer before featurizing it.
  ale_->getScreenRGB(state_->getObservationBuffer());
  return std::make_tuple(
      packInput(state_->computeFeature()), clippedReward, terminalSignal);
}