def forward()

in models/decoders/voxel1.py [0:0]


    def forward(self, encoding):
        """Blend 16 affine warps of ``self.grid`` into one dense warp field.

        Component ``k`` has:

        * a per-axis scale ``s_k`` from ``self.warps``;
        * a 3x3 matrix ``R_k`` produced by ``self.quat`` from the 4-vector
          output of ``self.warpr`` (presumably quaternion -> rotation matrix;
          confirm against ``self.quat``);
        * a translation ``t_k`` from ``self.warpt``, scaled by 0.1.

        Every point ``p`` of ``self.grid`` is mapped to
        ``y_k = s_k * (R_k @ (p - t_k))``. A positive weight volume per
        component (``exp`` of ``self.weightbranch``, viewed as 32x32x32) is
        sampled at ``y_k`` with border padding. The result is the
        weight-normalised average of the ``y_k``.

        Args:
            encoding: tensor whose first dimension is the batch size ``B``.

        Returns:
            Tensor of shape ``(B, 3, D, H, W)``, where ``(D, H, W)`` are the
            spatial dims of ``self.grid``. Assumes ``self.grid`` is
            ``(1, D, H, W, 3)`` (TODO confirm against ``__init__``).
        """
        ncomp = 16  # number of affine components mixed together
        batch = encoding.size(0)

        warps = self.warps(encoding).view(batch, ncomp, 3)
        warpr = self.warpr(encoding).view(batch, ncomp, 4)
        warpt = self.warpt(encoding).view(batch, ncomp, 3) * 0.1
        warprot = self.quat(warpr.view(-1, 4)).view(batch, ncomp, 3, 3)

        weight = torch.exp(self.weightbranch(encoding).view(batch, ncomp, 32, 32, 32))

        # Transform the grid ONCE for all components; the result is reused
        # for both sampling and blending. unsqueeze(-5) inserts the component
        # axis directly in front of (D, H, W, 3).
        offset = self.grid.unsqueeze(-5) - warpt[:, :, None, None, None, :]
        # coords[..., j] = sum_k R[j, k] * offset[..., k]  ==  R @ (p - t)
        coords = torch.einsum('bnxyzk,bnjk->bnxyzj', offset, warprot)
        coords = coords * warps[:, :, None, None, None, :]  # (B, ncomp, D, H, W, 3)
        spatial = coords.shape[2:5]

        # Fold the component axis into the batch axis. A single grid_sample
        # then samples each component's weight volume at that component's own
        # coordinates, which is equivalent to one call per component.
        # NOTE(review): align_corners is left unspecified, as before. On
        # PyTorch >= 1.3 that means align_corners=False plus a UserWarning;
        # pin it once the supported PyTorch version is confirmed.
        warpedweight = F.grid_sample(
            weight.flatten(0, 1).unsqueeze(1),
            coords.flatten(0, 1),
            padding_mode='border',
        ).view(batch, ncomp, *spatial)

        # Weight-normalised blend of the per-component coordinates. The clamp
        # keeps the denominator away from zero where sampled weights vanish.
        total = warpedweight.sum(dim=1).clamp(min=0.001)  # (B, D, H, W)
        warp = (warpedweight[..., None] * coords).sum(dim=1) / total[..., None]

        return warp.permute(0, 4, 1, 2, 3)