in models/decoders/mvp.py [0:0]
def forward(self,
        encoding,
        viewpos,
        condinput : Optional[torch.Tensor]=None,
        renderoptions : Optional[Dict[str, str]]=None,
        trainiter : int=-1,
        evaliter : Optional[torch.Tensor]=None,
        losslist : Optional[List[str]]=None,
        modelmatrix : Optional[torch.Tensor]=None):
    """
    Decode a frame encoding into primitive payloads (template), primitive
    placement (position / rotation / scale) and optional losses.

    Parameters
    ----------
    encoding : torch.Tensor [B, 256]
        Encoding of current frame
    viewpos : torch.Tensor [B, 3]
        Viewing position of target camera view
    condinput : torch.Tensor [B, ?], optional
        Additional conditioning input (e.g., headpose). When given it is
        concatenated to `encoding` along dim 1 before `self.enc`.
    renderoptions : dict
        Options for rendering (e.g., rendering debug images). Required
        (asserted non-None). Keys read here: "algo", "chlast", "half",
        "viewaxes", "colorprims", "viewslab". The dict is also forwarded
        to the rgb/alpha/warp sub-decoders.
    trainiter : int,
        Current training iteration. Motion deltas are zeroed while
        `trainiter <= self.postrainstart`; in the separate-alpha path,
        alpha is forced to 1 while `trainiter <= self.alphatrainstart`.
    evaliter : torch.Tensor, optional
        Only read when the "viewslab" option is enabled, in which case it
        is required. Presumably shape [B] (it is indexed with
        `[:, None, None]`) -- TODO confirm against caller.
    losslist : list,
        List of losses to compute and return. Required (asserted
        non-None). Recognized names: "primvolsum", "logprimscalevar".
    modelmatrix : torch.Tensor, optional
        Not used by this method.

    Returns
    -------
    result : dict,
        Always contains "template", "primpos", "primrot", "primscale".
        Contains "warp" when `self.warpdec` is set and "verts" (decoded
        vertex positions) when `self.nogeo` is false. "primtransf" is
        only added when non-None, which never happens in this method.
    losses : dict,
        One entry per recognized name found in `losslist`.
    """
    # asserts are kept (instead of raising a different exception type) so the
    # failure mode seen by callers is unchanged
    assert renderoptions is not None
    assert losslist is not None

    if condinput is not None:
        encoding = torch.cat([encoding, condinput], dim=1)

    encoding = self.enc(encoding)

    viewdirs = F.normalize(viewpos, dim=1)

    # factor nprims into an nprimsy x nprimsx grid: square when nprims is a
    # perfect square, otherwise a layout derived from sqrt(nprims // 2)
    if int(math.sqrt(self.nprims)) ** 2 == self.nprims:
        nprimsy = int(math.sqrt(self.nprims))
    else:
        nprimsy = int(math.sqrt(self.nprims // 2))
    nprimsx = self.nprims // nprimsy

    assert nprimsx * nprimsy == self.nprims

    if not self.nogeo:
        # decode mesh vertices
        geo = self.geobranch(encoding)
        geo = geo.view(encoding.size(0), -1, 3)
        geo = geo * self.vertstd + self.vertmean

        # placement of primitives on mesh
        uvheight, uvwidth = self.barim.size(0), self.barim.size(1)
        stridey = uvheight // nprimsy
        stridex = uvwidth // nprimsx

        # get subset of vertices and texture map coordinates to compute TBN matrix
        # (one sample per primitive, taken at the centre of each stride cell)
        v0 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 0], :]
        v1 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 1], :]
        v2 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 2], :]

        vt0 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 0], :]
        vt1 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 1], :]
        vt2 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 2], :]

        # barycentric blend of the three sampled vertices, in volradius units
        primposmesh = (
                self.barim[None, stridey//2::stridey, stridex//2::stridex, 0, None] * v0 +
                self.barim[None, stridey//2::stridey, stridex//2::stridex, 1, None] * v1 +
                self.barim[None, stridey//2::stridey, stridex//2::stridex, 2, None] * v2
                ).view(v0.size(0), self.nprims, 3) / self.volradius

        # compute TBN matrix
        primrotmesh = compute_tbn(v0, v1, v2, vt0, vt1, vt2)

        # decode motion deltas
        primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding)
        if trainiter <= self.postrainstart:
            # multiply by zero (rather than replacing) so shapes/dtypes/devices
            # and the autograd graph stay the same
            primposdelta = primposdelta * 0.
            primrvecdelta = primrvecdelta * 0.
            primscaledelta = primscaledelta * 0.

        # compose mesh transform with deltas
        primpos = primposmesh + primposdelta * 0.01
        primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.01)
        primrot = torch.bmm(
                primrotmesh.view(-1, 3, 3),
                primrotdelta.view(-1, 3, 3)).view(encoding.size(0), self.nprims, 3, 3)
        primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * torch.exp(primscaledelta * 0.01)

        primtransf = None
    else:
        geo = None

        # decode motion deltas
        primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding)
        if trainiter <= self.postrainstart:
            primposdelta = primposdelta * 0.
            primrvecdelta = primrvecdelta * 0.
            # scale delta is used multiplicatively below, so neutral is 1
            primscaledelta = primscaledelta * 0. + 1.

        primpos = primposdelta * 0.3
        primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.3)
        # NOTE(review): element-wise exp of the matrix returned by
        # axisangle_to_matrix looks unusual for a rotation; left unchanged
        # because existing checkpoints may depend on it -- confirm intent.
        primrot = torch.exp(primrotdelta * 0.01)
        primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * primscaledelta

        primtransf = None

    # options
    algo = renderoptions.get("algo")
    chlast = renderoptions.get("chlast")
    half = renderoptions.get("half")
    # evaluate the flags once; every use below tests exactly these conditions
    usechlast = chlast is not None and bool(chlast)
    usehalf = half is not None and bool(half)

    if self.rgbadec is not None:
        # shared rgb and alpha branch
        # per-channel affine rescale: x25 + 100 on the first three channels,
        # identity on the fourth
        scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)
        bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)
        if usechlast:
            # channel on the last axis
            scale = scale[None, None, None, None, None, :]
            bias = bias[None, None, None, None, None, :]
        else:
            # channel on axis 2
            scale = scale[None, None, :, None, None, None]
            bias = bias[None, None, :, None, None, None]

        templatein = torch.cat([encoding, viewdirs], dim=1)
        if usehalf:
            templatein = templatein.half()
        template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
        template = bias + scale * template
        if not self.notplateact:
            template = F.relu(template)
        if usehalf:
            template = template.float()
    else:
        # separate rgb (view-conditioned) and alpha (encoding only) branches
        templatein = torch.cat([encoding, viewdirs], dim=1)
        if usehalf:
            templatein = templatein.half()
        primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions)
        primrgb = primrgb * 25. + 100.
        if not self.notplateact:
            primrgb = F.relu(primrgb)

        templatein = encoding
        if usehalf:
            templatein = templatein.half()
        primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
        if not self.notplateact:
            primalpha = F.relu(primalpha)

        if trainiter <= self.alphatrainstart:
            primalpha = primalpha * 0. + 1.

        # channel axis is last for algo 4 or chlast, otherwise axis 2
        if algo is not None and int(algo) == 4:
            template = torch.cat([primrgb, primalpha], dim=-1)
        elif usechlast:
            template = torch.cat([primrgb, primalpha], dim=-1)
        else:
            template = torch.cat([primrgb, primalpha], dim=2)
        if usehalf:
            template = template.float()

    if self.warpdec is not None:
        # decoded warp is a small offset added to an identity sampling grid
        # in [-1, 1]; the meshgrid tuple is reversed before stacking
        warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01
        warp = warp + torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., self.primsize[2], device=encoding.device),
            torch.linspace(-1., 1., self.primsize[1], device=encoding.device),
            torch.linspace(-1., 1., self.primsize[0], device=encoding.device))[::-1],
            dim=-1 if usechlast else 0)[None, None, :, :, :, :]
    else:
        warp = None

    # debugging / visualization
    viewaxes = renderoptions.get("viewaxes")
    colorprims = renderoptions.get("colorprims")
    viewslab = renderoptions.get("viewslab")

    # add axes to primitives
    if viewaxes is not None and bool(viewaxes):
        # NOTE(review): these writes index the channel on axis 2, which
        # matches the non-chlast layout above; with chlast they would hit a
        # spatial axis instead -- confirm whether that combination is used.
        mid3 = template.size(3) // 2
        mid4 = template.size(4) // 2
        mid5 = template.size(5) // 2
        template[:, :, 3, mid3:mid3+1, mid4:mid4+1, :] = 2550.
        template[:, :, 0, mid3:mid3+1, mid4:mid4+1, :] = 2550.
        template[:, :, 3, mid3:mid3+1, :, mid5:mid5+1] = 2550.
        template[:, :, 1, mid3:mid3+1, :, mid5:mid5+1] = 2550.
        template[:, :, 3, :, mid4:mid4+1, mid5:mid5+1] = 2550.
        template[:, :, 2, :, mid4:mid4+1, mid5:mid5+1] = 2550.

    # give each primitive a unique color
    if colorprims is not None and bool(colorprims):
        lightdir = -torch.tensor([1., 1., 1.], device=template.device)
        lightdir = lightdir / torch.sqrt(torch.sum(lightdir ** 2))

        # per-voxel axis-aligned normal: the sign of whichever coordinate has
        # the largest magnitude (ties set several components; normalized below)
        zz, yy, xx = torch.meshgrid(
            torch.linspace(-1., 1., self.primsize[2], device=template.device),
            torch.linspace(-1., 1., self.primsize[1], device=template.device),
            torch.linspace(-1., 1., self.primsize[0], device=template.device))
        primnormalx = torch.where(
                (torch.abs(xx) >= torch.abs(yy)) & (torch.abs(xx) >= torch.abs(zz)),
                torch.sign(xx) * torch.ones_like(xx),
                torch.zeros_like(xx))
        primnormaly = torch.where(
                (torch.abs(yy) >= torch.abs(xx)) & (torch.abs(yy) >= torch.abs(zz)),
                torch.sign(yy) * torch.ones_like(xx),
                torch.zeros_like(xx))
        primnormalz = torch.where(
                (torch.abs(zz) >= torch.abs(xx)) & (torch.abs(zz) >= torch.abs(yy)),
                torch.sign(zz) * torch.ones_like(xx),
                torch.zeros_like(xx))
        primnormal = torch.stack([primnormalx, primnormaly, primnormalz], dim=-1)
        primnormal = F.normalize(primnormal, dim=-1)

        # fixed seed so the random per-primitive colors are the same on every
        # call; note this reseeds the global torch RNG
        torch.manual_seed(123456)

        if usechlast:
            template[:] = torch.rand(1, template.size(1), 1, 1, 1, template.size(-1), device=template.device) * 255.
            template[:, :, :, :, :, 3] = 1000.
        else:
            template[:] = torch.rand(1, template.size(1), template.size(2), 1, 1, 1, device=template.device) * 255.
            template[:, :, 3, :, :, :] = 1000.

        # simple directional shading; the light term does not depend on the
        # channel layout so it is computed once
        lightdir0 = torch.sum(primrot[:, :, :, :] * lightdir[None, None, :, None], dim=-2)
        shade = torch.sum(lightdir0[:, :, None, None, None, :] * primnormal, dim=-1)
        if usechlast:
            template[:, :, :, :, :, :3] *= 1.2 * shade[:, :, :, :, :, None].clamp(min=0.05)
        else:
            template[:, :, :3, :, :, :] *= 1.2 * shade[:, :, None, :, :, :].clamp(min=0.05)

    # view slab as a 2d grid
    if viewslab is not None and bool(viewslab):
        assert evaliter is not None

        # lay primitives out on the same nprimsy x nprimsx factorization
        # computed above, so non-square primitive counts also reshape cleanly
        # to [B, nprims, 3] (for perfect squares both equal sqrt(nprims))
        yy, xx = torch.meshgrid(
                torch.linspace(0., 1., nprimsy, device=template.device),
                torch.linspace(0., 1., nprimsx, device=template.device))
        primpos0 = torch.stack([xx*1.5, 0.75-yy*1.5, xx*0.+0.5], dim=-1)[None, :, :, :].repeat(template.size(0), 1, 1, 1).view(-1, self.nprims, 3)
        primrot0 = torch.eye(3, device=template.device)[None, None, :, :].repeat(template.size(0), self.nprims, 1, 1)
        primrot0.data[:, :, 1, 1] *= -1.
        primscale0 = torch.ones((template.size(0), self.nprims, 3), device=template.device) * math.sqrt(self.nprims) * 1.25 #* 0.5

        # blend ramps 0 -> 1 as evaliter goes 256 -> 320, then is smoothed
        # with the cubic 3t^2 - 2t^3
        blend = ((evaliter - 256.) / 64.).clamp(min=0., max=1.)[:, None, None]
        blend = 3 * blend ** 2 - 2 * blend ** 3

        # position: linear; rotation: rotation_interp; scale: linear in log space
        primpos = (1. - blend) * primpos0 + blend * primpos
        primrot = models.utils.rotation_interp(primrot0, primrot, blend)
        primscale = torch.exp((1. - blend) * torch.log(primscale0) + blend * torch.log(primscale))

    losses = {}

    # prior on primitive volume
    if "primvolsum" in losslist:
        # sum over primitives of prod(1 / scale)
        losses["primvolsum"] = torch.sum(torch.prod(1. / primscale, dim=-1), dim=-1)

    if "logprimscalevar" in losslist:
        # mean squared deviation of log-scale from its mean over dim 1
        logprimscale = torch.log(primscale)
        logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True)
        losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2)

    result = {
            "template": template,
            "primpos": primpos,
            "primrot": primrot,
            "primscale": primscale}
    if primtransf is not None:
        result["primtransf"] = primtransf
    if warp is not None:
        result["warp"] = warp
    if geo is not None:
        result["verts"] = geo
    return result, losses