in models/decoders/voxel1.py
def forward(self, encoding, viewpos, losslist=[]):
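    """Decode a latent code into an RGBA (color + opacity) template volume and warp fields.

    Argument roles are inferred from how they are used below; exact tensor shapes
    depend on the rest of the model.
      encoding: per-frame latent code, fed to every decoder branch.
      viewpos: camera position, used only when the decoder is view-conditioned.
      losslist: names of regularizer terms to compute (e.g. "tvl1").
    Returns a dict with the decoded "template", the per-voxel "warp" field (None if
    that branch is disabled), optional global-warp parameters, and a "losses" dict.
    """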
    scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)[None, :, None, None, None]
    bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)[None, :, None, None, None]
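    # scale/bias above form a per-channel affine map applied to the raw template output
    # before the softplus below: color channels get *25 + 100, the alpha channel is left as-is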

    # run template branch
    viewdir = viewpos / torch.sqrt(torch.sum(viewpos ** 2, dim=-1, keepdim=True))
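    # viewdir: camera position normalized to a unit viewing direction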
    templatein = torch.cat([encoding, viewdir], dim=1) if self.viewconditioned else encoding
    template = self.template(templatein)
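    # when view-conditioned, self.template sees the view direction and predicts
    # view-dependent color, with opacity handled by the separate branch below;
    # otherwise it presumably outputs color and opacity together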
    if self.viewconditioned:
        # run alpha branch without viewpoint information
        template = torch.cat([template, self.templatealpha(encoding)], dim=1)
    # scale up to 0-255 range approximately
    template = F.softplus(bias + scale * template)
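    # the softplus also guarantees strictly positive values in every channel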

    # compute warp voxel field
    warp = self.warp(encoding) if self.warp is not None else None
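    # warp, when the branch exists, is a dense voxel field of sampling coordinates,
    # presumably used downstream to warp the template during raymarching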

    if self.globalwarp:
        # compute single affine transformation
        gwarps = 1.0 * torch.exp(0.05 * self.gwarps(encoding).view(encoding.size(0), 3))
        gwarpr = self.gwarpr(encoding).view(encoding.size(0), 4) * 0.1
        gwarpt = self.gwarpt(encoding).view(encoding.size(0), 3) * 0.025
        gwarprot = self.quat(gwarpr.view(-1, 4)).view(encoding.size(0), 3, 3)
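        # together these define a per-sample similarity transform: per-axis scale
        # (exp keeps it positive, ~1 at initialization), a rotation built from a
        # quaternion via self.quat, and a translation; the 0.05/0.1/0.025 factors
        # presumably keep the transform near identity early in training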

    losses = {}

    # tv-L1 prior
    if "tvl1" in losslist:
        logalpha = torch.log(1e-5 + template[:, -1, :, :, :])
        losses["tvl1"] = torch.mean(torch.sqrt(1e-5 +
            (logalpha[:, :-1, :-1, 1:] - logalpha[:, :-1, :-1, :-1]) ** 2 +
            (logalpha[:, :-1, 1:, :-1] - logalpha[:, :-1, :-1, :-1]) ** 2 +
            (logalpha[:, 1:, :-1, :-1] - logalpha[:, :-1, :-1, :-1]) ** 2))
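        # this is an isotropic total-variation penalty on log opacity; the 1e-5 terms
        # avoid log(0) and a non-differentiable sqrt at zero, and the term presumably
        # encourages spatially smooth alpha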
return {"template": template, "warp": warp,
**({"gwarps": gwarps, "gwarprot": gwarprot, "gwarpt": gwarpt} if self.globalwarp else {}),
"losses": losses}