def forward()

in models/decoders/mvp.py [0:0]


    def forward(self,
            encoding,
            viewpos,
            condinput : Optional[torch.Tensor]=None,
            renderoptions : Optional[Dict[str, str]]=None,
            trainiter : int=-1,
            evaliter : Optional[torch.Tensor]=None,
            losslist : Optional[List[str]]=None,
            modelmatrix : Optional[torch.Tensor]=None):
        """
        Decode a frame encoding into primitive templates and transforms.

        Parameters
        ----------
        encoding : torch.Tensor [B, 256]
            Encoding of current frame
        viewpos : torch.Tensor [B, 3]
            Viewing position of target camera view
        condinput : torch.Tensor [B, ?]
            Additional conditioning input (e.g., headpose); concatenated to
            `encoding` along dim 1 when given
        renderoptions : dict
            Options for rendering (e.g., rendering debug images). Must not be
            None. Keys read here: "algo", "chlast", "half", "viewaxes",
            "colorprims", "viewslab". The dict is also forwarded to the
            template / warp sub-decoders.
        trainiter : int,
            Current training iteration
        evaliter : torch.Tensor
            Only read (and then required) when the "viewslab" option is on,
            where it drives the blend between the flat debug layout and the
            decoded layout. Indexed as a 1-D tensor (presumably [B]).
        losslist : list,
            List of losses to compute and return. Must not be None.
            Recognized names: "primvolsum", "logprimscalevar".
        modelmatrix : torch.Tensor
            Not used by this method.

        Returns
        -------
        result : dict,
            Contains "template", "primpos", "primrot" and "primscale", plus
            "warp" when a warp decoder is present and "verts" (predicted
            vertex positions) when geometry is decoded.
        losses : dict,
            The requested losses, keyed by name.
        """
        # Both are required in practice despite their Optional defaults.
        assert renderoptions is not None
        assert losslist is not None

        if condinput is not None:
            encoding = torch.cat([encoding, condinput], dim=1)

        encoding = self.enc(encoding)

        viewdirs = F.normalize(viewpos, dim=1)

        # factor nprims into an nprimsy x nprimsx grid (square when possible)
        if int(math.sqrt(self.nprims)) ** 2 == self.nprims:
            nprimsy = int(math.sqrt(self.nprims))
        else:
            nprimsy = int(math.sqrt(self.nprims // 2))
        nprimsx = self.nprims // nprimsy

        assert nprimsx * nprimsy == self.nprims

        if not self.nogeo:
            # decode mesh vertices
            geo = self.geobranch(encoding)
            geo = geo.view(encoding.size(0), -1, 3)
            geo = geo * self.vertstd + self.vertmean

            # placement of primitives on mesh
            uvheight, uvwidth = self.barim.size(0), self.barim.size(1)
            stridey = uvheight // nprimsy
            stridex = uvwidth // nprimsx

            # get subset of vertices and texture map coordinates to compute TBN matrix
            v0 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 0], :]
            v1 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 1], :]
            v2 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 2], :]

            vt0 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 0], :]
            vt1 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 1], :]
            vt2 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 2], :]

            # barycentric combination of the three vertices, in volume units
            primposmesh = (
                    self.barim[None, stridey//2::stridey, stridex//2::stridex, 0, None] * v0 +
                    self.barim[None, stridey//2::stridey, stridex//2::stridex, 1, None] * v1 +
                    self.barim[None, stridey//2::stridey, stridex//2::stridex, 2, None] * v2
                    ).view(v0.size(0), self.nprims, 3) / self.volradius

            # compute TBN matrix
            primrotmesh = compute_tbn(v0, v1, v2, vt0, vt1, vt2)

            # decode motion deltas
            primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding)
            if trainiter <= self.postrainstart:
                # warm-up: keep primitives at their mesh-derived transform
                primposdelta = primposdelta * 0.
                primrvecdelta = primrvecdelta * 0.
                primscaledelta = primscaledelta * 0.

            # compose mesh transform with deltas
            primpos = primposmesh + primposdelta * 0.01
            primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.01)
            primrot = torch.bmm(
                    primrotmesh.view(-1, 3, 3),
                    primrotdelta.view(-1, 3, 3)).view(encoding.size(0), self.nprims, 3, 3)
            primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * torch.exp(primscaledelta * 0.01)

            primtransf = None
        else:
            geo = None

            # decode motion deltas
            primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding)
            if trainiter <= self.postrainstart:
                # warm-up: zero position/rotation deltas, unit scale factor
                primposdelta = primposdelta * 0.
                primrvecdelta = primrvecdelta * 0.
                primscaledelta = primscaledelta * 0. + 1.

            primpos = primposdelta * 0.3
            primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.3)
            # NOTE(review): this is an element-wise exp of the (presumably
            # rotation) matrix entries, so the result is not a rotation
            # matrix; the mesh branch applies exp only to the scale. Looks
            # unintended, but kept as-is to avoid changing outputs -- confirm.
            primrot = torch.exp(primrotdelta * 0.01)
            primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * primscaledelta

            primtransf = None

        # options
        algo = renderoptions.get("algo")
        chlast = renderoptions.get("chlast")
        half = renderoptions.get("half")

        # An option counts as enabled when present and truthy (for string
        # values that means any non-empty string).
        usechlast = chlast is not None and bool(chlast)
        usehalf = half is not None and bool(half)

        if self.rgbadec is not None:
            # shared rgb and alpha branch
            scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)
            bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)
            # broadcast the per-channel scale/bias along the channel axis
            if usechlast:
                scale = scale[None, None, None, None, None, :]
                bias = bias[None, None, None, None, None, :]
            else:
                scale = scale[None, None, :, None, None, None]
                bias = bias[None, None, :, None, None, None]

            templatein = torch.cat([encoding, viewdirs], dim=1)
            if usehalf:
                templatein = templatein.half()
            template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            template = bias + scale * template
            if not self.notplateact:
                template = F.relu(template)
            if usehalf:
                template = template.float()
        else:
            # separate branches: rgb is view-conditioned, alpha is not
            templatein = torch.cat([encoding, viewdirs], dim=1)
            if usehalf:
                templatein = templatein.half()
            primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            primrgb = primrgb * 25. + 100.
            if not self.notplateact:
                primrgb = F.relu(primrgb)

            templatein = encoding
            if usehalf:
                templatein = templatein.half()
            primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            if not self.notplateact:
                primalpha = F.relu(primalpha)

            if trainiter <= self.alphatrainstart:
                # warm-up: force alpha to one
                primalpha = primalpha * 0. + 1.

            if algo is not None and int(algo) == 4:
                template = torch.cat([primrgb, primalpha], dim=-1)
            elif usechlast:
                template = torch.cat([primrgb, primalpha], dim=-1)
            else:
                template = torch.cat([primrgb, primalpha], dim=2)
            if usehalf:
                template = template.float()

        if self.warpdec is not None:
            # decoded warp is a small offset added to an identity grid in [-1, 1]
            warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01
            warp = warp + torch.stack(torch.meshgrid(
                torch.linspace(-1., 1., self.primsize[2], device=encoding.device),
                torch.linspace(-1., 1., self.primsize[1], device=encoding.device),
                torch.linspace(-1., 1., self.primsize[0], device=encoding.device))[::-1],
                dim=-1 if usechlast else 0)[None, None, :, :, :, :]
        else:
            warp = None

        # debugging / visualization
        viewaxes = renderoptions.get("viewaxes")
        colorprims = renderoptions.get("colorprims")
        viewslab = renderoptions.get("viewslab")

        # add axes to primitives
        if viewaxes is not None and bool(viewaxes):
            # Saturate channel 3 plus one of channels 0/1/2 along each of the
            # three lines through the primitive centre.
            # NOTE(review): dim 2 is indexed as the channel axis regardless
            # of "chlast" -- confirm this is only meant for that layout.
            c3 = template.size(3)//2
            c4 = template.size(4)//2
            c5 = template.size(5)//2
            template[:, :, 3, c3:c3+1, c4:c4+1, :] = 2550.
            template[:, :, 0, c3:c3+1, c4:c4+1, :] = 2550.
            template[:, :, 3, c3:c3+1, :, c5:c5+1] = 2550.
            template[:, :, 1, c3:c3+1, :, c5:c5+1] = 2550.
            template[:, :, 3, :, c4:c4+1, c5:c5+1] = 2550.
            template[:, :, 2, :, c4:c4+1, c5:c5+1] = 2550.

        # give each primitive a unique color
        if colorprims is not None and bool(colorprims):
            lightdir = -torch.tensor([1., 1., 1.], device=template.device)
            lightdir = lightdir / torch.sqrt(torch.sum(lightdir ** 2))
            zz, yy, xx = torch.meshgrid(
                torch.linspace(-1., 1., self.primsize[2], device=template.device),
                torch.linspace(-1., 1., self.primsize[1], device=template.device),
                torch.linspace(-1., 1., self.primsize[0], device=template.device))
            # per-voxel normal: signed unit step along whichever axis has the
            # largest magnitude (ties contribute to several axes, then normalized)
            primnormalx = torch.where(
                    (torch.abs(xx) >= torch.abs(yy)) & (torch.abs(xx) >= torch.abs(zz)),
                    torch.sign(xx) * torch.ones_like(xx),
                    torch.zeros_like(xx))
            primnormaly = torch.where(
                    (torch.abs(yy) >= torch.abs(xx)) & (torch.abs(yy) >= torch.abs(zz)),
                    torch.sign(yy) * torch.ones_like(xx),
                    torch.zeros_like(xx))
            primnormalz = torch.where(
                    (torch.abs(zz) >= torch.abs(xx)) & (torch.abs(zz) >= torch.abs(yy)),
                    torch.sign(zz) * torch.ones_like(xx),
                    torch.zeros_like(xx))
            primnormal = torch.stack([primnormalx, primnormaly, primnormalz], dim=-1)
            primnormal = F.normalize(primnormal, dim=-1)

            # Fixed seed so the random colors are the same on every call.
            # Note this reseeds the global torch RNG as a side effect.
            torch.manual_seed(123456)

            # light direction expressed through each primitive's rotation
            lightdir0 = torch.sum(primrot * lightdir[None, None, :, None], dim=-2)
            shade = torch.sum(lightdir0[:, :, None, None, None, :] * primnormal, dim=-1)

            if usechlast:
                template[:] = torch.rand(1, template.size(1), 1, 1, 1, template.size(-1), device=template.device) * 255.
                template[:, :, :, :, :, 3] = 1000.
                template[:, :, :, :, :, :3] *= 1.2 * shade[:, :, :, :, :, None].clamp(min=0.05)
            else:
                template[:] = torch.rand(1, template.size(1), template.size(2), 1, 1, 1, device=template.device) * 255.
                template[:, :, 3, :, :, :] = 1000.
                template[:, :, :3, :, :, :] *= 1.2 * shade[:, :, None, :, :, :].clamp(min=0.05)

        # view slab as a 2d grid
        if viewslab is not None and bool(viewslab):
            assert evaliter is not None

            # Lay the primitives out on the same nprimsy x nprimsx grid used
            # above, so the reshape to nprims below also holds when nprims is
            # not a perfect square (identical to before for square counts).
            yy, xx = torch.meshgrid(
                    torch.linspace(0., 1., nprimsy, device=template.device),
                    torch.linspace(0., 1., nprimsx, device=template.device))
            primpos0 = torch.stack([xx*1.5, 0.75-yy*1.5, xx*0.+0.5], dim=-1)[None, :, :, :].repeat(template.size(0), 1, 1, 1).view(-1, self.nprims, 3)
            primrot0 = torch.eye(3, device=template.device)[None, None, :, :].repeat(template.size(0), self.nprims, 1, 1)
            primrot0.data[:, :, 1, 1] *= -1.
            primscale0 = torch.ones((template.size(0), self.nprims, 3), device=template.device) * math.sqrt(self.nprims) * 1.25

            # smoothstep ramp from 0 to 1 over evaliter in [256, 320]
            blend = ((evaliter - 256.) / 64.).clamp(min=0., max=1.)[:, None, None]
            blend = 3 * blend ** 2 - 2 * blend ** 3

            primpos = (1. - blend) * primpos0 + blend * primpos
            primrot = models.utils.rotation_interp(primrot0, primrot, blend)
            # interpolate scale in log space
            primscale = torch.exp((1. - blend) * torch.log(primscale0) + blend * torch.log(primscale))

        losses = {}

        # prior on primitive volume
        if "primvolsum" in losslist:
            losses["primvolsum"] = torch.sum(torch.prod(1. / primscale, dim=-1), dim=-1)

        # variance of log-scale across primitives (dim 1)
        if "logprimscalevar" in losslist:
            logprimscale = torch.log(primscale)
            logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True)
            losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2)

        result = {
                "template": template,
                "primpos": primpos,
                "primrot": primrot,
                "primscale": primscale}
        # primtransf is None on both paths above, so this key is never set here
        if primtransf is not None:
            result["primtransf"] = primtransf
        if warp is not None:
            result["warp"] = warp
        if geo is not None:
            result["verts"] = geo
        return result, losses