in models/decoders/voxel1.py [0:0]
def forward(self, encoding):
    """Decode a dense warp field as a weighted blend of 16 affine components.

    For each component ``i`` the sampling coordinate at every grid point is
    ``warps_i * (warprot_i @ (self.grid - warpt_i))``. The component's weight
    volume is resampled (``F.grid_sample``, border padding) at those
    coordinates. The output is the per-voxel weighted average of the 16
    component coordinates, using the resampled weights.

    Args:
        encoding: latent code tensor; dim 0 is the batch dimension.

    Returns:
        Tensor of shape ``(B, 3, D, H, W)``, where ``(D, H, W)`` are the
        spatial dims of ``self.grid`` (presumably ``(1 or B, D, H, W, 3)`` —
        TODO confirm against ``__init__``).
    """
    ncomp = 16  # number of blended affine components
    wres = 32   # spatial resolution of each component's weight volume
    batch = encoding.size(0)

    warps = self.warps(encoding).view(batch, ncomp, 3)
    warpr = self.warpr(encoding).view(batch, ncomp, 4)
    warpt = self.warpt(encoding).view(batch, ncomp, 3) * 0.1
    # self.quat maps each 4-vector to a 3x3 matrix (presumably quaternion ->
    # rotation matrix; verify against its definition).
    warprot = self.quat(warpr.view(-1, 4)).view(batch, ncomp, 3, 3)
    # exp keeps every blend weight strictly positive.
    weight = torch.exp(self.weightbranch(encoding).view(batch, ncomp, wres, wres, wres))

    # Per-component coordinates, computed ONCE and reused below (the previous
    # version rebuilt this identical expression a second time for the blend).
    # The sum over dim 5 computes warprot_i @ (grid - warpt_i).
    coords = [
        torch.sum(
            (self.grid - warpt[:, None, None, None, i, :])[:, :, :, :, None, :]
            * warprot[:, None, None, None, i, :, :],
            dim=5,
        ) * warps[:, None, None, None, i, :]
        for i in range(weight.size(1))
    ]

    # NOTE(review): align_corners is not passed, so PyTorch >= 1.3 uses
    # align_corners=False (and warns); older versions behaved like True.
    # Confirm which convention the model was trained with.
    warpedweight = torch.cat([
        F.grid_sample(weight[:, i:i + 1, :, :, :], coord, padding_mode='border')
        for i, coord in enumerate(coords)], dim=1)

    blended = torch.sum(torch.stack([
        warpedweight[:, i, :, :, :, None] * coord
        for i, coord in enumerate(coords)], dim=1), dim=1)
    # clamp guards the normalisation against a vanishing total weight.
    warp = blended / torch.sum(warpedweight, dim=1).clamp(min=0.001)[:, :, :, :, None]
    return warp.permute(0, 4, 1, 2, 3)