in models/volsamplers/warpvoxel.py [0:0]
def forward(self, pos, template, warp=None, gwarps=None, gwarprot=None, gwarpt=None, viewtemplate=False, **kwargs):
    """Sample ``template`` at ``pos``, optionally after global and voxel warps.

    Unless ``viewtemplate`` is True, the sampling grid is transformed first:

    1. Global warp (only when ``gwarps`` is not None): the grid is offset by
       ``-gwarpt``, multiplied by the per-batch matrix ``gwarprot``
       (presumably a rotation; TODO confirm) and scaled per axis by ``gwarps``.
    2. Voxel warp (only when ``warp`` is not None): ``warp`` is sampled at the
       grid. If ``self.displacementwarp`` is set, the result is added to the
       grid. Otherwise it replaces the grid, and every sample whose pre-warp
       coordinate lies outside the open interval (-1, 1) on any axis is zeroed
       in the output.

    Args:
        pos: sampling grid of rank 5, (B, D, H, W, 3), in the form accepted
            by ``F.grid_sample``.
        template: 5-D volume that is sampled to produce the output.
        warp: optional 5-D volume sampled at ``pos`` (see step 2).
        gwarps: optional per-axis scale, indexed as (B, 3).
        gwarprot: per-batch matrix, indexed as (B, 3, 3). Only read when
            ``gwarps`` is not None.
        gwarpt: per-batch offset, indexed as (B, 3). Only read when ``gwarps``
            is not None.
        viewtemplate: if True, skip both warps and sample ``template`` at the
            raw ``pos``.
        **kwargs: accepted and ignored.

    Returns:
        Tuple ``(sampled[:, :3], sampled[:, 3:])``: the first three channels
        of the sampled volume and the remaining channels (presumably color
        and opacity; verify against callers).
    """
    # NOTE(review): F.grid_sample is called without ``align_corners``; its
    # default changed in PyTorch 1.3, so results depend on the installed
    # version. Consider pinning the value the model was trained with.
    mask = None

    if not viewtemplate:
        if gwarps is not None:
            # Broadcast (B, D, H, W, 1, 3) against (B, 1, 1, 1, 3, 3) and reduce
            # the last axis: a batched matrix-vector product gwarprot @ offset.
            offset = pos - gwarpt[:, None, None, None, :]
            rotated = (offset[:, :, :, :, None, :] * gwarprot[:, None, None, None, :, :]).sum(dim=-1)
            pos = rotated * gwarps[:, None, None, None, :]

        if warp is not None:
            # grid_sample yields channels-first (B, C, D, H, W); move channels
            # last so the result has the layout of a sampling grid again.
            warped = F.grid_sample(warp, pos).permute(0, 2, 3, 4, 1)
            if self.displacementwarp:
                pos = pos + warped
            else:
                # 1.0 where every coordinate of the pre-warp grid is in (-1, 1).
                inside = (pos > -1.) & (pos < 1.)
                mask = inside.all(dim=-1).float()
                pos = warped

    sampled = F.grid_sample(template, pos)
    if mask is not None:
        # Add a channel axis so the (B, D, H, W) mask broadcasts over channels.
        sampled = sampled * mask.unsqueeze(1)

    return sampled[:, :3], sampled[:, 3:]