in models/decoders/nv.py [0:0]
def forward(self,
        encoding,
        viewpos,
        condinput : Optional[torch.Tensor]=None,
        renderoptions : Optional[Dict[str, str]]=None,
        trainiter : int=-1,
        evaliter : Optional[torch.Tensor]=None,
        losslist : Optional[List[str]]=None,
        modelmatrix : Optional[torch.Tensor]=None):
    """Decode one frame's latent code into a single volumetric primitive.

    The primitive is always placed at the origin with identity rotation and
    unit scale; only its RGBA voxel template (and, when ``self.warpdec`` is
    set, a warp field) is predicted by the sub-decoders.

    Parameters
    ----------
    encoding : torch.Tensor [B, C]
        Latent code of the current frame.
    viewpos : torch.Tensor [B, 3]
        Viewing position of the target camera. It is normalized along dim 1
        and concatenated to the encoded latent as input to the colour
        (or shared RGBA) head.
    condinput : torch.Tensor [B, ?], optional
        Extra conditioning input (e.g. head pose), concatenated to
        ``encoding`` along dim 1 before ``self.enc``.
    renderoptions : dict
        Required (asserted non-None). The keys ``"algo"``, ``"chlast"`` and
        ``"half"`` are read here, and the dict is passed on to every
        sub-decoder.
    trainiter : int
        Current training iteration, passed to every sub-decoder. With
        separate colour/alpha heads, alpha is forced to 1 while
        ``trainiter <= self.alphatrainstart``.
    evaliter : torch.Tensor, optional
        Not used in this method.
    losslist : list of str
        Required (asserted non-None). Names of losses to compute; the
        entries recognized here are ``"primvolsum"`` and
        ``"logprimscalevar"``.
    modelmatrix : torch.Tensor, optional
        Not used in this method.

    Returns
    -------
    result : dict
        ``"template"``: RGBA voxels with channels on dim 2, or on the last
        dim when ``chlast`` is set (or ``algo == 4`` with separate heads).
        ``"primpos"``: zeros [B, 1, 3].
        ``"primrot"``: identity [B, 1, 3, 3].
        ``"primscale"``: ones [B, 1, 3].
        ``"warp"``: only present when ``self.warpdec`` is not None.
    losses : dict
        The requested losses, keyed by name.
    """
    assert renderoptions is not None
    assert losslist is not None

    # Optional conditioning is appended to the code before the shared encoder.
    if condinput is not None:
        encoding = torch.cat([encoding, condinput], dim=1)
    code = self.enc(encoding)
    viewdirs = F.normalize(viewpos, dim=1)

    batch = code.size(0)
    dev = code.device

    # Exactly one primitive: origin, identity rotation, unit scale.
    primpos = torch.zeros(batch, 1, 3, device=dev)
    primrot = torch.eye(3, device=dev).view(1, 1, 3, 3).repeat(batch, 1, 1, 1)
    primscale = torch.ones(batch, 1, 3, device=dev)

    # Rendering flags (may be absent -> None).
    algo = renderoptions.get("algo")
    chlast = renderoptions.get("chlast")
    half = renderoptions.get("half")
    use_half = half is not None and bool(half)
    use_chlast = chlast is not None and bool(chlast)

    # The colour / RGBA head consumes the code together with the view direction.
    codeview = torch.cat([code, viewdirs], dim=1)
    decin = codeview.half() if use_half else codeview

    if self.rgbadec is None:
        # Separate colour and opacity heads.
        primrgb = self.rgbdec(decin, trainiter=trainiter, renderoptions=renderoptions)
        primrgb = primrgb * 25. + 100.
        alphain = code.half() if use_half else code
        primalpha = self.alphadec(alphain, trainiter=trainiter, renderoptions=renderoptions)
        if not self.notplateact:
            primrgb = F.relu(primrgb)
            primalpha = F.relu(primalpha)
        if trainiter <= self.alphatrainstart:
            # Multiply-and-add instead of ones_like: presumably keeps the alpha
            # head in the autograd graph with zero gradient — verify intent.
            primalpha = primalpha * 0. + 1.
        chandim = 2
        if (algo is not None and int(algo) == 4) or use_chlast:
            chandim = -1
        template = torch.cat([primrgb, primalpha], dim=chandim)
    else:
        # Shared RGBA head: rgb channels get x*25+100, alpha is left as-is.
        affine = [1, 1, 1, 1, 1, 4] if use_chlast else [1, 1, 4, 1, 1, 1]
        gain = torch.tensor([25., 25., 25., 1.], device=dev).view(affine)
        offset = torch.tensor([100., 100., 100., 0.], device=dev).view(affine)
        rawrgba = self.rgbadec(decin, trainiter=trainiter, renderoptions=renderoptions)
        template = offset + gain * rawrgba
        if not self.notplateact:
            template = F.relu(template)

    if use_half:
        template = template.float()

    warp: Optional[torch.Tensor] = None
    if self.warpdec is not None:
        warpout = self.warpdec(code, trainiter=trainiter, renderoptions=renderoptions)
        axis = torch.linspace(-1., 1., self.warpprimsize, device=dev)
        grids = torch.meshgrid(axis, axis, axis)
        # Reversed order: coordinate 0 varies along the last spatial axis.
        identity = torch.stack([grids[2], grids[1], grids[0]],
                               dim=-1 if use_chlast else 0)
        # Decoder output is a small (x0.01) offset around the identity grid.
        warp = (warpout * 0.01 + identity[None, None]).contiguous()

    losses: Dict[str, torch.Tensor] = {}
    if "primvolsum" in losslist:
        # Per-sample sum over primitives of the inverse primitive volume.
        losses["primvolsum"] = (1. / primscale).prod(dim=-1).sum(dim=-1)
    if "logprimscalevar" in losslist:
        logscale = primscale.log()
        centered = logscale - logscale.mean(dim=1, keepdim=True)
        losses["logprimscalevar"] = centered.pow(2).mean()

    result = {
        "template": template,
        "primpos": primpos,
        "primrot": primrot,
        "primscale": primscale,
    }
    if warp is not None:
        result["warp"] = warp
    return result, losses