def forward()

in models/volsamplers/warpvoxel.py [0:0]


    def forward(self, pos, template, warp=None, gwarps=None, gwarprot=None, gwarpt=None, viewtemplate=False, **kwargs):
        """Sample ``template`` at ``pos`` after optional global and local warps.

        Unless ``viewtemplate`` is true, the sample positions are transformed
        in two optional stages before the template lookup:

        1. Global warp (only when ``gwarps`` is not None): ``gwarpt`` is
           subtracted from every position, the result is multiplied by the
           matrix ``gwarprot`` and then scaled per axis by ``gwarps``.
           ``gwarprot`` and ``gwarpt`` are used without a None check, so they
           must be supplied together with ``gwarps``.
        2. Local warp (only when ``warp`` is not None): ``warp`` is sampled
           at the current positions. If ``self.displacementwarp`` is true the
           sampled values are added to the positions as an offset; otherwise
           they replace the positions, and output samples whose incoming
           position had any coordinate outside the open interval (-1, 1)
           are multiplied by zero.

        Args:
            pos: 5-D tensor of sample positions, indexed here as
                ``[b, :, :, :, coord]`` (presumably normalized to [-1, 1] as
                ``F.grid_sample`` expects -- confirm with the caller).
            template: volume passed to ``F.grid_sample`` as its input.
            warp: optional warp volume passed to ``F.grid_sample``.
            gwarps: optional 2-D per-batch scale, broadcast over positions.
            gwarprot: 3-D per-batch matrix, required when ``gwarps`` is set.
            gwarpt: 2-D per-batch offset, required when ``gwarps`` is set.
            viewtemplate: if true, skip both warp stages and sample the
                template directly at ``pos``.
            **kwargs: accepted and ignored.

        Returns:
            A tuple ``(first, rest)`` where ``first`` holds channels ``0:3``
            of the sampled template and ``rest`` holds channels ``3:``.
        """
        # NOTE(review): every F.grid_sample call below relies on the library
        # defaults (interpolation mode, padding mode, align_corners); confirm
        # these match what the model was trained with.
        apply_warps = not viewtemplate
        mask = None

        if apply_warps and gwarps is not None:
            centered = pos - gwarpt[:, None, None, None, :]
            # Broadcast multiply and sum over the last axis: per-position
            # matrix-vector product of gwarprot with the centered position.
            rotated = torch.sum(
                centered[:, :, :, :, None, :] * gwarprot[:, None, None, None, :, :],
                dim=-1)
            pos = rotated * gwarps[:, None, None, None, :]

        if apply_warps and warp is not None:
            if not self.displacementwarp:
                # The mask is taken from the positions *before* they are
                # replaced by the warp lookup.
                inside = (pos > -1.) * (pos < 1.)
                mask = torch.prod(inside, dim=-1).float()
                pos = F.grid_sample(warp, pos).permute(0, 2, 3, 4, 1)
            else:
                offset = F.grid_sample(warp, pos).permute(0, 2, 3, 4, 1)
                pos = pos + offset

        sampled = F.grid_sample(template, pos)
        if mask is not None:
            # Insert a channel axis so the mask broadcasts over all channels.
            sampled = sampled * mask[:, None, :, :, :]

        first = sampled[:, :3, :, :, :]
        rest = sampled[:, 3:, :, :, :]
        return first, rest