in atari/game_state.h [53:82]
/// Builds the frame-stacked, resized grayscale observation for the current step.
///
/// Pipeline:
///   1. Fetch the raw observation and view it as a 1x3x(height)x(width) batch.
///      NOTE(review): assumes getObservation() returns a contiguous floating
///      tensor with 3*height*width elements laid out channel-first — confirm.
///   2. Bilinearly resize to sHeight x sWidth (third argument `true` is
///      align_corners).
///   3. Collapse RGB to one channel with weights 0.21 / 0.72 / 0.07.
///   4. Multiply by 255 and cast to uint8. The cast truncates toward zero; the
///      input is presumably in [0, 1] — TODO confirm against getObservation().
///   5. Append the frame to the rolling history stackedS_. When the history is
///      empty, the frame is replicated frameStack times so the stack is full.
///
/// @return A newly allocated uint8 tensor of shape {frameStack, sHeight, sWidth};
///         index 0 holds the oldest frame and the last index the newest.
torch::Tensor computeFeature() {
  // Steps 1-2: reshape into a single-image NCHW batch, then resize.
  torch::Tensor batched = getObservation().view({1, 3, height, width});
  torch::Tensor resized =
      torch::upsample_bilinear2d(batched, {sHeight, sWidth}, true)
          .view({3, sHeight, sWidth});

  // Step 3: luma = 0.21 * r + 0.72 * g + 0.07 * b
  torch::Tensor red = resized[0];
  torch::Tensor green = resized[1];
  torch::Tensor blue = resized[2];
  torch::Tensor luma = 0.21 * red + 0.72 * green + 0.07 * blue;

  // Step 4: quantize to 8 bits per pixel.
  torch::Tensor frame = luma.mul(255.).to(torch::kUInt8);
  assert(frame.dim() == 2);
  assert(frame.size(0) == sHeight);
  assert(frame.size(1) == sWidth);

  // Step 5: slide the history window by one frame, or prime it if empty.
  const bool historyEmpty = stackedS_.size() == 0;
  if (!historyEmpty) {
    assert(static_cast<int>(stackedS_.size()) == frameStack);
    stackedS_.pop_front();  // drop the oldest frame
    stackedS_.push_back(frame);
  } else {
    // No previous frames yet: fill every slot with the current frame.
    for (int slot = 0; slot < frameStack; ++slot) {
      stackedS_.push_back(frame);
    }
  }
  assert(static_cast<int>(stackedS_.size()) == frameStack);

  // Copy the history into a fresh output tensor so callers never alias
  // the tensors held in stackedS_.
  torch::Tensor stacked =
      torch::zeros({frameStack, sHeight, sWidth}, torch::kUInt8);
  for (int slot = 0; slot < frameStack; ++slot) {
    stacked[slot].copy_(stackedS_[slot]);
  }
  return stacked;
}