def forward()

in minihack/agent/polybeast/models/base.py [0:0]


    def forward(self, inputs, coordinates):
        """Calculates centered crop around given x,y coordinates.

        Args:
           inputs [B x H x W] or [B x H x W x C] (channels last), where
              H == self.height and W == self.width.
           coordinates [B x 2] x,y coordinates (x indexes W, y indexes H).

        Returns:
           [B x H' x W'] for 3-D inputs, or [B x H' x W' x C] for 4-D
           inputs: the inputs cropped and centered around the x,y
           coordinates, rounded and cast to long. H' x W' is the spatial
           size of self.width_grid / self.height_grid. Positions that fall
           outside the input are filled with zeros (the default padding of
           F.grid_sample).
        """
        assert inputs.shape[1] == self.height, "expected %d but found %d" % (
            self.height,
            inputs.shape[1],
        )
        assert inputs.shape[2] == self.width, "expected %d but found %d" % (
            self.width,
            inputs.shape[2],
        )

        # F.grid_sample expects a channels-first [B x C x H x W] input.
        channels_last = inputs.dim() != 3
        if channels_last:
            # [B x H x W x C] -> [B x C x H x W]
            inputs = inputs.permute(0, 3, 1, 2)
        else:
            # [B x H x W] -> [B x 1 x H x W]
            inputs = inputs.unsqueeze(1)
        inputs = inputs.float()

        x = coordinates[:, 0]
        y = coordinates[:, 1]

        # Offset of the requested center from the image center, expressed in
        # grid_sample's normalized coordinates (with align_corners=True the
        # range [-1, 1] spans size - 1 pixels, i.e. one pixel is 2/(size-1)).
        x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
        y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)

        # [B x H' x W' x 2]; the last axis holds (x, y) sampling locations.
        grid = torch.stack(
            [
                self.width_grid[None, :, :] + x_shift[:, None, None],
                self.height_grid[None, :, :] + y_shift[:, None, None],
            ],
            dim=3,
        )

        # Rounding removes interpolation noise before the cast to long.
        crop = torch.round(
            F.grid_sample(inputs, grid, align_corners=True)
        ).long()

        if channels_last:
            # [B x C x H' x W'] -> [B x H' x W' x C]
            return crop.permute(0, 2, 3, 1)

        # Drop the singleton channel axis added above: -> [B x H' x W']
        return crop.squeeze(1)