torch::Tensor computeFeature()

in atari/game_state.h [53:82]


  // Builds the stacked-frame feature tensor for the current observation.
  //
  // The observation is viewed as a 3-channel image, bilinearly resized to
  // sHeight x sWidth, collapsed to a single channel and stored as uint8. The
  // resulting frame is appended to the frame history (stackedS_) while the
  // oldest entry is dropped; when the history is still empty it is filled
  // with the same frame repeated frameStack times.
  //
  // Returns a newly allocated uint8 tensor of shape
  // {frameStack, sHeight, sWidth}, oldest frame at index 0.
  torch::Tensor computeFeature() {
    const torch::Tensor raw = getObservation();

    // Resize as a one-image batch; the trailing `true` is align_corners.
    const torch::Tensor batched = raw.view({1, 3, height, width});
    const torch::Tensor resized =
        torch::upsample_bilinear2d(batched, {sHeight, sWidth}, true);
    const torch::Tensor channels = resized.view({3, sHeight, sWidth});

    // Weighted channel sum: 0.21 * r + 0.72 * g + 0.07 * b
    // (channel order r, g, b is taken from the formula above).
    const torch::Tensor weighted =
        0.21 * channels[0] + 0.72 * channels[1] + 0.07 * channels[2];

    // Scale to the 0..255 range and store as bytes.
    // NOTE(review): presumably the observation values lie in [0, 1] - confirm.
    const torch::Tensor frame = (weighted * 255.).to(torch::kUInt8);
    assert(frame.dim() == 2);
    assert(frame.size(0) == sHeight);
    assert(frame.size(1) == sWidth);

    if (!stackedS_.empty()) {
      // Steady state: drop the oldest frame, append the newest.
      assert(static_cast<int>(stackedS_.size()) == frameStack);
      stackedS_.pop_front();
      stackedS_.push_back(frame);
    } else {
      // Empty history: repeat the current frame to fill every slot.
      for (int slot = 0; slot < frameStack; ++slot) {
        stackedS_.push_back(frame);
      }
    }
    assert(static_cast<int>(stackedS_.size()) == frameStack);

    // Copy the history into one tensor so the result does not share storage
    // with the frames kept in stackedS_.
    torch::Tensor stacked =
        torch::zeros({frameStack, sHeight, sWidth}, torch::kUInt8);
    for (int slot = 0; slot < frameStack; ++slot) {
      stacked[slot].copy_(stackedS_[slot]);
    }
    return stacked;
  }