torch::Tensor getLegalActionMask()

in atari/atari_env.h [19:35]


/// Builds a float32 mask over the emulator's legal action set.
///
/// The returned 1-D tensor has one element per entry of
/// `ale.getLegalActionSet()`.
///
/// @param ale          Emulator interface queried for its action sets.
/// @param useMinAction When false, every element of the mask is 1. When true,
///                     the mask starts as all zeros and the element at the
///                     integer value of each action returned by
///                     `ale.getMinimalActionSet()` is set to 1.
/// @return The mask tensor described above.
torch::Tensor getLegalActionMask(ALEInterface& ale, bool useMinAction) {
  const auto legalSet = ale.getLegalActionSet();
  const int numLegal = static_cast<int>(legalSet.size());

  // No minimal-action filtering requested: every slot is enabled.
  if (!useMinAction) {
    return torch::ones({numLegal}, torch::kFloat32);
  }

  torch::Tensor result = torch::zeros({numLegal}, torch::kFloat32);
  const auto minimalSet = ale.getMinimalActionSet();
  auto slots = result.accessor<float, 1>();
  for (const auto act : minimalSet) {
    // The action's integer value is used directly as its slot in the mask.
    const int slot = static_cast<int>(act);
    // Debug-only guard (assert is disabled when NDEBUG is defined).
    assert(0 <= slot && slot < slots.size(0));
    slots[slot] = 1;
  }
  return result;
}