def forward()

in models/volumetric.py [0:0]


    def forward(self,
            camrot : torch.Tensor,
            campos : torch.Tensor,
            focal : torch.Tensor,
            princpt : torch.Tensor,
            camindex : Optional[torch.Tensor] = None,
            pixelcoords : Optional[torch.Tensor]=None,
            modelmatrix : Optional[torch.Tensor]=None,
            modelmatrixinv : Optional[torch.Tensor]=None,
            modelmatrix_next : Optional[torch.Tensor]=None,
            modelmatrixinv_next : Optional[torch.Tensor]=None,
            validinput : Optional[torch.Tensor]=None,
            avgtex : Optional[torch.Tensor]=None,
            avgtex_next : Optional[torch.Tensor]=None,
            verts : Optional[torch.Tensor]=None,
            verts_next : Optional[torch.Tensor]=None,
            fixedcamimage : Optional[torch.Tensor]=None,
            encoding : Optional[torch.Tensor]=None,
            image : Optional[torch.Tensor]=None,
            imagemask : Optional[torch.Tensor]=None,
            imagevalid : Optional[torch.Tensor]=None,
            bg : Optional[torch.Tensor]=None,
            renderoptions : Optional[dict]=None,
            trainiter : int=-1,
            evaliter : Optional[torch.Tensor]=None,
            outputlist : Optional[list]=None,
            losslist : Optional[list]=None,
            **kwargs):
        """
        Encode (unless an encoding is given), decode, and raymarch the
        volumetric representation from a target camera, then composite the
        background and compute the requested outputs and losses.

        Parameters
        ----------
        camrot : torch.Tensor [B, 3, 3]
            Rotation matrix of target view camera
        campos : torch.Tensor [B, 3]
            Position of target view camera
        focal : torch.Tensor [B, 2]
            Focal length of target view camera
        princpt : torch.Tensor [B, 2]
            Principal point of target view camera
        camindex : torch.Tensor[int32], optional [B]
            Camera index within the list of all cameras. When given (and
            renderoptions["nocolcorrect"] is not set) it is used for color
            correction via self.colorcal; it is also passed to self.bgmodel.
        pixelcoords : torch.Tensor [B, H', W', 2]
            Pixel coordinates to render of the target view camera. Declared
            optional for interface compatibility, but it is read
            unconditionally, so it must be provided.
        modelmatrix, modelmatrix_next, modelmatrixinv_next : torch.Tensor, optional
            Accepted for interface compatibility; not read by this method.
        modelmatrixinv : torch.Tensor, optional
            Inverse object transform. Its [:, :3, :3] block is used as a
            rotation and [:, :3, 3] as a translation, so it needs at least
            4 columns. When given, the camera is re-expressed relative to the
            object before decoding/raymarching.
        validinput : torch.Tensor, optional [B]
            Per-batch-element validity weight (used for missing images).
            Defaults to all ones when omitted.
        avgtex : torch.Tensor, optional [B, 3, 1024, 1024]
            Texture map averaged from all viewpoints (possible encoder input)
        avgtex_next, verts_next : torch.Tensor, optional
            Accepted for interface compatibility; not read by this method.
        verts : torch.Tensor, optional [B, 7306, 3]
            Mesh vertex positions (possible encoder input, and the target of
            the "vertmse" loss).
        fixedcamimage : torch.Tensor, optional [B, 3, 512, 334]
            Camera images from one or more cameras that are always the same
            (i.e., unrelated to target)
        encoding : torch.Tensor, optional [B, 256]
            Direct encodings (overrides encoder)
        image : torch.Tensor, optional [B, 3, H, W]
            Target image. Resampled at pixelcoords when its spatial size
            differs from pixelcoords.
        imagemask : torch.Tensor, optional [B, 1, H, W]
            Target image mask for computing reconstruction loss
        imagevalid : torch.Tensor, optional [B]
            Extra per-batch-element multiplier on the image loss weight.
        bg : torch.Tensor, optional [B, 3, H, W]
            Background composited behind the volume. If a background model is
            active, it is passed to self.bgmodel and replaced by its output.
        renderoptions : dict, optional
            Rendering/raymarching options (e.g., "dt", "dtstd", "nobg",
            "nocolcorrect", "unbiastminmax", "tminmaxblocks"). Copied
            internally, so the caller's dict is never modified.
        trainiter : int
            Training iteration number
        evaliter : torch.Tensor, optional
            Forwarded unchanged to the decoder and raymarcher.
        outputlist : list, optional
            Values to return (e.g., image reconstruction, debug output)
        losslist : list, optional
            Losses to output (e.g., image reconstruction loss, priors)
        **kwargs
            Additional inputs. "tex" and "texmask" are read from here when
            "trgbmse" is in losslist or "trgbsqerr" is in outputlist.

        Returns
        -------
        result : dict
            Contains outputs specified in outputlist (e.g., image rgb
            reconstruction "irgbrec")
        losses : dict
            Losses to optimize

        Raises
        ------
        ValueError
            If a texture loss/output is requested without "tex" and "texmask".
        """
        # Avoid shared mutable defaults. renderoptions is copied because "dt"
        # may be jittered below; mutating the caller's dict would let the
        # jitter accumulate across calls.
        renderoptions = {} if renderoptions is None else dict(renderoptions)
        outputlist = [] if outputlist is None else outputlist
        losslist = [] if losslist is None else losslist

        # Missing validity weights mean "every batch element is valid".
        if validinput is None:
            validinput = torch.ones(campos.size(0), dtype=campos.dtype, device=campos.device)

        resultout = {}
        resultlosses = {}

        aestart = time.time()

        # encode/get encoding
        if encoding is None:
            if "enctime" in outputlist:
                torch.cuda.synchronize()
                encstart = time.time()
            encout, enclosses = self.encoder(
                    *[dict(verts=verts, avgtex=avgtex, fixedcamimage=fixedcamimage)[k] for k in self.encoderinputs],
                    losslist=losslist)
            if "enctime" in outputlist:
                torch.cuda.synchronize()
                encend = time.time()
                resultout["enctime"] = encend - encstart

            encoding = encout["encoding"]
            resultlosses.update(enclosses)

        # compute relative viewing position
        if modelmatrixinv is not None:
            viewrot = torch.bmm(camrot, modelmatrixinv[:, :3, :3])
            viewpos = torch.bmm((campos[:, :] - modelmatrixinv[:, :3, 3])[:, None, :], modelmatrixinv[:, :3, :3])[:, 0, :]
        else:
            viewrot = camrot
            viewpos = campos

        # decode volumetric representation
        if "dectime" in outputlist:
            torch.cuda.synchronize()
            decstart = time.time()
        if isinstance(self.decoder, torch.jit.ScriptModule):
            # torchscript requires statically typed dict
            renderoptionstyped : Dict[str, str] = {k: str(v) for k, v in renderoptions.items()}
        else:
            renderoptionstyped = renderoptions
        decout, declosses = self.decoder(
                encoding,
                viewpos,
                renderoptions=renderoptionstyped,
                trainiter=trainiter,
                evaliter=evaliter,
                losslist=losslist)
        if "dectime" in outputlist:
            torch.cuda.synchronize()
            decend = time.time()
            resultout["dectime"] = decend - decstart
        resultlosses.update(declosses)

        # compute vertex loss
        if "vertmse" in losslist:
            weight = validinput[:, None, None].expand_as(verts)

            if hasattr(self, "vertmask"):
                weight = weight * self.vertmask[None, :, None]

            # The decoder output is re-standardized and compared directly with
            # verts. NOTE(review): this assumes verts arrives already
            # standardized with vertmean/vertstd — confirm against the dataset.
            vertsrecstd = (decout["verts"] - self.vertmean) / self.vertstd

            vertsqerr = weight * (verts - vertsrecstd) ** 2

            vertmse = torch.sum(vertsqerr.view(vertsqerr.size(0), -1), dim=-1)
            vertmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)

            resultlosses["vertmse"] = (vertmse, vertmse_weight)

        # compute texture loss
        if "trgbmse" in losslist or "trgbsqerr" in outputlist:
            # tex/texmask are not named parameters; they arrive via **kwargs.
            tex = kwargs.get("tex")
            texmask = kwargs.get("texmask")
            if tex is None or texmask is None:
                raise ValueError("trgbmse/trgbsqerr require 'tex' and 'texmask' inputs")

            weight = (validinput[:, None, None, None] * texmask[:, None, :, :].float()).expand_as(tex).contiguous()

            # re-standardize both reconstruction and target (on the device
            # the decoder produced the texture on)
            texrec = decout["tex"]
            texmean = self.texmean.to(texrec.device)
            texrecstd = (texrec - texmean) / self.texstd
            texstd = (tex - texmean) / self.texstd

            texsqerr = weight * (texstd - texrecstd) ** 2

            if "trgbsqerr" in outputlist:
                resultout["trgbsqerr"] = texsqerr

            # texture rgb mean-squared-error
            if "trgbmse" in losslist:
                texmse = torch.sum(texsqerr.view(texsqerr.size(0), -1), dim=-1)
                texmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)

                resultlosses["trgbmse"] = (texmse, texmse_weight)

        # Size (W, H) of the pixel grid that pixelcoords refers to: the target
        # image when it differs in size from pixelcoords, else pixelcoords.
        if image is not None and pixelcoords.size()[1:3] != image.size()[2:4]:
            imagesize = torch.tensor(image.size()[3:1:-1], dtype=torch.float32, device=pixelcoords.device)
        else:
            imagesize = torch.tensor(pixelcoords.size()[2:0:-1], dtype=torch.float32, device=pixelcoords.device)

        # Map pixel coordinates to [-1, 1] (align_corners=True convention).
        samplecoords = pixelcoords * 2. / (imagesize[None, None, None, :] - 1.) - 1.

        # compute ray directions
        if self.cudaraydirs:
            raypos, raydir, tminmax = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, self.volradius)
        else:
            raydir = compute_raydirs_ref(pixelcoords, viewrot, focal, princpt)
            raypos, tminmax = compute_rmbounds(viewpos, raydir, self.volradius)

        # Randomly jitter the step size (log-normal factor). Safe to write
        # because renderoptions is a local copy.
        if "dtstd" in renderoptions:
            renderoptions["dt"] = renderoptions["dt"] * \
                    torch.exp(torch.randn(1) * renderoptions.get("dtstd")).item()

        # Snap ray bounds down to a multiple of the normalized step size.
        if renderoptions.get("unbiastminmax", False):
            stepsize = renderoptions["dt"] / self.volradius
            tminmax = torch.floor(tminmax / stepsize) * stepsize

        # Share ray bounds within each (by x bx) pixel block by taking the
        # per-block minimum; H and W must be divisible by the block size.
        if renderoptions.get("tminmaxblocks", False):
            bx, by = renderoptions.get("blocksize", (8, 16))
            H, W = tminmax.size(1), tminmax.size(2)
            tminmax = tminmax.view(tminmax.size(0), H // by, by, W // bx, bx, 2)
            tminmax = tminmax.amin(dim=[2, 4], keepdim=True)
            tminmax = tminmax.repeat(1, 1, by, 1, bx, 1)
            tminmax = tminmax.view(tminmax.size(0), H, W, 2)

        # raymarch
        if "rmtime" in outputlist:
            torch.cuda.synchronize()
            rmstart = time.time()
        rayrgba, rmlosses = self.raymarcher(raypos, raydir, tminmax,
                decout=decout, renderoptions=renderoptions,
                trainiter=trainiter, evaliter=evaliter, losslist=losslist)
        resultlosses.update(rmlosses)
        if "rmtime" in outputlist:
            torch.cuda.synchronize()
            rmend = time.time()
            resultout["rmtime"] = rmend - rmstart

        # The raymarcher returns either (rgb, alpha) or a packed RGBA tensor.
        if isinstance(rayrgba, tuple):
            rayrgb, rayalpha = rayrgba
        else:
            rayrgb, rayalpha = rayrgba[:, :3, :, :].contiguous(), rayrgba[:, 3:4, :, :].contiguous()

        # beta distribution prior on final opacity
        if "alphapr" in losslist:
            # 2.20727 ~= -(log(0.1) + log(1.1)): shifts the prior so it is 0
            # at alpha in {0, 1} and positive in between.
            flatalpha = rayalpha.view(rayalpha.size(0), -1)
            alphaprior = torch.mean(
                    torch.log(0.1 + flatalpha) +
                    torch.log(0.1 + 1. - flatalpha) + 2.20727, dim=-1)
            resultlosses["alphapr"] = alphaprior

        # color correction
        if camindex is not None and not renderoptions.get("nocolcorrect", False):
            rayrgb = self.colorcal(rayrgb, camindex)

        # background decoder
        bgstart = None
        if self.bgmodel is not None and not renderoptions.get("nobg", False):
            if "bgtime" in outputlist:
                torch.cuda.synchronize()
                bgstart = time.time()

            # Background rays use campos/camrot directly (not the
            # model-relative viewpos/viewrot used for the volume).
            bgraypos, bgraydir, bgtminmax = compute_raydirs(campos, camrot, focal, princpt, pixelcoords, self.volradius)

            rayposend = bgraypos + bgraydir * bgtminmax[..., 1:2]

            bg = self.bgmodel(bg, camindex, campos, rayposend, bgraydir, samplecoords, trainiter=trainiter)

        # alpha matting
        if bg is not None:
            rayrgb = rayrgb + (1. - rayalpha) * bg

            if "bg" in outputlist:
                resultout["bg"] = bg

            # Only timed when the background model actually ran above.
            if "bgtime" in outputlist and bgstart is not None:
                torch.cuda.synchronize()
                bgend = time.time()
                resultout["bgtime"] = bgend - bgstart

        if "irgbrec" in outputlist:
            resultout["irgbrec"] = rayrgb
        if "irgbarec" in outputlist:
            resultout["irgbarec"] = torch.cat([rayrgb, rayalpha], dim=1)
        if "irgbflip" in outputlist:
            # Interleave reconstruction (i % 4 < 2) and target image; requires
            # image, and uses it before any subsampling below.
            resultout["irgbflip"] = torch.cat([rayrgb[i:i+1] if i % 4 < 2 else image[i:i+1]
                for i in range(image.size(0))], dim=0)

        # image rgb loss
        if image is not None and trainiter > self.irgbmsestart:
            # subsample image
            if pixelcoords.size()[1:3] != image.size()[2:4]:
                image = F.grid_sample(image, samplecoords, align_corners=True)
                if imagemask is not None:
                    imagemask = F.grid_sample(imagemask, samplecoords, align_corners=True)

            # compute reconstruction loss weighting
            weight = torch.ones_like(image) * validinput[:, None, None, None]
            if imagevalid is not None:
                weight = weight * imagevalid[:, None, None, None]
            if imagemask is not None:
                weight = weight * imagemask

            if "irgbsqerr" in outputlist:
                irgbsqerr_nonorm = (weight * (image - rayrgb) ** 2).contiguous()
                resultout["irgbsqerr"] = torch.sqrt(irgbsqerr_nonorm.mean(dim=1, keepdim=True))

            # standardize
            rayrgb = (rayrgb - self.imagemean) / self.imagestd
            image = (image - self.imagemean) / self.imagestd

            irgbsqerr = (weight * (image - rayrgb) ** 2).contiguous()

            if "irgbmse" in losslist:
                irgbmse = torch.sum(irgbsqerr.view(irgbsqerr.size(0), -1), dim=-1)
                irgbmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)

                resultlosses["irgbmse"] = (irgbmse, irgbmse_weight)

        aeend = time.time()
        if "aetime" in outputlist:
            resultout["aetime"] = aeend - aestart

        return resultout, resultlosses