in minihack/agent/polybeast/models/base.py [0:0]
def forward(self, inputs, coordinates):
    """Calculates centered crop around given x,y coordinates.

    Args:
        inputs: [B x H x W] or [B x H x W x C] tensor. H and W must equal
            ``self.height`` and ``self.width``.
        coordinates: [B x 2] tensor of (x, y) coordinates, where x indexes
            the W axis and y indexes the H axis.

    Returns:
        An int64 tensor cropped around the x,y coordinates. It is
        [B x H' x W'] for 3-D inputs and [B x H' x W' x C] for 4-D inputs,
        where H' x W' is the shape of ``self.height_grid`` /
        ``self.width_grid``. Sampled values are rounded to the nearest
        integer. Positions that fall outside the input read as 0
        (``grid_sample``'s default zero padding).
    """
    assert inputs.shape[1] == self.height, "expected %d but found %d" % (
        self.height,
        inputs.shape[1],
    )
    assert inputs.shape[2] == self.width, "expected %d but found %d" % (
        self.width,
        inputs.shape[2],
    )

    # grid_sample expects channels-first input: [B x C x H x W].
    has_channels = inputs.dim() != 3
    if has_channels:
        # [B x H x W x C] -> [B x C x H x W]
        inputs = inputs.permute(0, 3, 1, 2)
    else:
        # [B x H x W] -> [B x 1 x H x W]
        inputs = inputs.unsqueeze(1)
    inputs = inputs.float()

    x = coordinates[:, 0]
    y = coordinates[:, 1]

    # Shift the precomputed offset grids into grid_sample's normalized
    # coordinates. With align_corners=True, -1 / +1 refer to the first / last
    # pixel, so one pixel spans 2 / (size - 1).
    # NOTE(review): this assumes width_grid/height_grid hold integer pixel
    # offsets scaled by 2 / (size - 1). Under that assumption, samples land
    # exactly on pixel (x + offset) only for odd sizes. For even sizes,
    # (size - 1) / 2 != size // 2, so samples are shifted by half a pixel.
    x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
    y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)

    # grid[..., 0] is the x (width) coordinate and grid[..., 1] is the
    # y (height) coordinate, as grid_sample requires: [B x H' x W' x 2].
    grid = torch.stack(
        [
            self.width_grid[None, :, :] + x_shift[:, None, None],
            self.height_grid[None, :, :] + y_shift[:, None, None],
        ],
        dim=3,
    )

    crop = torch.round(F.grid_sample(inputs, grid, align_corners=True)).long()

    if has_channels:
        # [B x C x H' x W'] -> [B x H' x W' x C]
        return crop.permute(0, 2, 3, 1)
    # Remove only the singleton channel added above. Squeezing the 4-D path
    # too would drop a genuine C == 1 channel.
    # [B x 1 x H' x W'] -> [B x H' x W']
    return crop.squeeze(1)