def forward()

in models/decoders/nv.py


    def forward(self,
            encoding,
            viewpos,
            condinput : Optional[torch.Tensor]=None,
            renderoptions : Optional[Dict[str, str]]=None,
            trainiter : int=-1,
            evaliter : Optional[torch.Tensor]=None,
            losslist : Optional[List[str]]=None,
            modelmatrix : Optional[torch.Tensor]=None):
        """
        Parameters
        ----------
        encoding : torch.Tensor [B, 256]
            Encoding of current frame
        viewpos : torch.Tensor [B, 3]
            Viewing position of target camera view
        condinput : torch.Tensor [B, ?]
            Additional conditioning input (e.g., headpose)
        renderoptions : dict
            Options for rendering (e.g., rendering debug images)
        trainiter : int,
            Current training iteration
        losslist : list,
            List of losses to compute and return

        Returns
        -------
        result : dict,
            Contains predicted vertex positions, primitive contents and
            locations, scaling, and orientation, and any losses.
        """
        assert renderoptions is not None
        assert losslist is not None

        if condinput is not None:
            encoding = torch.cat([encoding, condinput], dim=1)

        encoding = self.enc(encoding)

        viewdirs = F.normalize(viewpos, dim=1)

        # a single volumetric primitive with identity placement:
        # zero offset, identity rotation, unit scale
        primpos = torch.zeros(encoding.size(0), 1, 3, device=encoding.device)
        primrot = torch.eye(3, device=encoding.device)[None, None, :, :].repeat(encoding.size(0), 1, 1, 1)
        primscale = torch.ones(encoding.size(0), 1, 3, device=encoding.device)

        # rendering options; values are strings (see the type annotation),
        # hence the bool()/int() conversions below
        algo = renderoptions.get("algo")
        chlast = renderoptions.get("chlast")
        half = renderoptions.get("half")

        if self.rgbadec is not None:
            # shared rgb and alpha branch
            # per-channel scale and bias for the raw decoder output: RGB scaled
            # by 25 with a bias of 100, alpha left unscaled
            scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)
            bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)
            # broadcast along the channel axis: last dimension for channels-last
            # layouts, third dimension otherwise
            if chlast is not None and bool(chlast):
                scale = scale[None, None, None, None, None, :]
                bias = bias[None, None, None, None, None, :]
            else:
                scale = scale[None, None, :, None, None, None]
                bias = bias[None, None, :, None, None, None]

            # decode the RGBA template from the encoding and view direction
            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            template = bias + scale * template
            if not self.notplateact:
                template = F.relu(template)
            if half is not None and bool(half):
                template = template.float()
        else:
            # separate branches: view-conditioned RGB and view-independent alpha
            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            primrgb = primrgb * 25. + 100.
            if not self.notplateact:
                primrgb = F.relu(primrgb)

            templatein = encoding
            if half is not None and bool(half):
                templatein = templatein.half()
            primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            if not self.notplateact:
                primalpha = F.relu(primalpha)

            # keep alpha fixed at 1 until alpha training starts
            if trainiter <= self.alphatrainstart:
                primalpha = primalpha * 0. + 1.

            if algo is not None and int(algo) == 4:
                template = torch.cat([primrgb, primalpha], dim=-1)
            elif chlast is not None and bool(chlast):
                template = torch.cat([primrgb, primalpha], dim=-1)
            else:
                template = torch.cat([primrgb, primalpha], dim=2)
            if half is not None and bool(half):
                template = template.float()

        if self.warpdec is not None:
            # warp field: small decoded offsets (scaled by 0.01) added to an
            # identity sampling grid spanning [-1, 1]^3, stacked channel-last
            # or channel-first to match the template layout
            warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01
            warp = warp + torch.stack(torch.meshgrid(
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device),
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device),
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device))[::-1],
                dim=-1 if chlast is not None and bool(chlast) else 0)[None, None, :, :, :, :]
            warp = warp.contiguous()
        else:
            warp = None

        losses = {}

        # prior on primitive volume
        if "primvolsum" in losslist:
            losses["primvolsum"] = torch.sum(torch.prod(1. / primscale, dim=-1), dim=-1)

        if "logprimscalevar" in losslist:
            logprimscale = torch.log(primscale)
            logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True)
            losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2)

        result = {
                "template": template,
                "primpos": primpos,
                "primrot": primrot,
                "primscale": primscale}
        if warp is not None:
            result["warp"] = warp
        return result, losses
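
A minimal calling sketch, assuming `dec` is an already-constructed decoder instance from models/decoders/nv.py; the constructor arguments, batch size, and option values below are illustrative assumptions, not taken from this file:

    import torch

    # hypothetical inputs matching the documented shapes
    batchsize = 2
    encoding = torch.randn(batchsize, 256)   # per-frame encoding [B, 256]
    viewpos = torch.randn(batchsize, 3)      # target camera position [B, 3]

    # renderoptions values are strings per the signature; an empty dict also
    # satisfies the assert and leaves algo/chlast/half at their defaults
    result, losses = dec.forward(
            encoding,
            viewpos,
            renderoptions={},
            trainiter=0,
            losslist=["primvolsum", "logprimscalevar"])

    # result contains "template", "primpos", "primrot", "primscale" and,
    # if a warp decoder is configured, "warp"; losses holds the requested priors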