in atari/atari_env.h [19:35]
/// Builds a per-action mask over ALE's legal action set.
///
/// @param ale           ALE instance whose action sets are queried.
/// @param useMinAction  If false, every legal action is enabled (all ones).
///                      If true, only the actions returned by
///                      ale.getMinimalActionSet() are set to 1; all others
///                      stay 0.
/// @return 1-D float32 tensor whose length is ale.getLegalActionSet().size().
///
/// Defined `inline` because it lives in a header: a plain (non-inline)
/// function definition would be emitted by every translation unit that
/// includes this header and violate the one-definition rule at link time.
inline torch::Tensor getLegalActionMask(ALEInterface& ale, bool useMinAction) {
  const auto numLegal = static_cast<int64_t>(ale.getLegalActionSet().size());
  if (!useMinAction) {
    return torch::ones({numLegal}, torch::kFloat32);
  }
  auto mask = torch::zeros({numLegal}, torch::kFloat32);
  auto maskAccessor = mask.accessor<float, 1>();
  // The mask is indexed by the numeric action value, which assumes the legal
  // action set enumerates actions 0..numLegal-1 in order.
  // NOTE(review): confirm this holds for the ALE version in use.
  for (const auto action : ale.getMinimalActionSet()) {
    const auto index = static_cast<int64_t>(action);
    const bool inRange = index >= 0 && index < numLegal;
    assert(inRange);
    // assert() is compiled out under NDEBUG; still guard the write so an
    // unexpected action value cannot write outside the tensor in release.
    if (inRange) {
      maskAccessor[index] = 1.0f;
    }
  }
  return mask;
}