def forward()

in extensions/mvpraymarch/mvpraymarch.py [0:0]


    def forward(self, raypos, raydir, stepsize, tminmax,
            primpos, primrot, primscale,
            template, warp,
            rayterm, gradmode, options):
        """Run the forward ray-marching kernel and return per-ray RGBA.

        Reads the kernel configuration from ``options``, checks the memory
        layout of the inputs with assertions, optionally builds an
        acceleration structure via ``build_accel``, allocates the output
        buffers, calls ``mvpraymarchlib.raymarch_forward`` and stashes
        everything needed later through ``self.save_for_backward``.

        Args:
            raypos, raydir: contiguous tensors whose dim 3 has size 3.
            stepsize: passed through to the kernel and kept on ``self``.
            tminmax: contiguous tensor whose dim 3 has size 2.
            primpos, primrot, primscale: each either None or a contiguous
                tensor whose dim 2 has size 3.
            template: contiguous 6-D tensor; its channel axis (last axis if
                ``options["chlast"]`` is truthy, otherwise axis 2) has size 4.
            warp: None, or a contiguous tensor whose channel axis (same
                axis choice as ``template``) has size 3.
            rayterm: ignored; the kernel always receives None in its place.
            gradmode: when truthy, a ``raysat`` buffer of shape (N, H, W, 3)
                filled with -1 is allocated and handed to the kernel.
            options: mapping providing the keys "algo", "usebvh",
                "sortprims", "randomorder", "maxhitboxes", "synchitboxes",
                "chlast", "fadescale", "fadeexp", "accum", "termthresh",
                "griddim" and "blocksize" (an (x, y) tuple or a single
                value, in which case y is 1).

        Returns:
            The (N, H, W, 4) tensor ``rayrgba`` handed to the kernel as its
            output buffer, where N, H, W are the first three sizes of
            ``raypos``.

        Raises:
            AssertionError: if a layout check fails.
            KeyError: if ``options`` lacks one of the required keys.
        """
        # Pull the kernel configuration out of `options`, in a fixed order.
        (algo, usebvh, sortprims, randomorder, maxhitboxes, synchitboxes,
         chlast, fadescale, fadeexp, accum, termthresh, griddim) = [
            options[key] for key in (
                "algo", "usebvh", "sortprims", "randomorder",
                "maxhitboxes", "synchitboxes", "chlast",
                "fadescale", "fadeexp", "accum", "termthresh", "griddim")]

        # A scalar block size means a one-row launch block.
        blocksize = options["blocksize"]
        blocksizex, blocksizey = (
            blocksize if isinstance(blocksize, tuple) else (blocksize, 1))

        # Per-ray inputs: contiguous, with a fixed size along dim 3.
        for raytensor, lastsize in ((raypos, 3), (raydir, 3), (tminmax, 2)):
            assert raytensor.is_contiguous() and raytensor.size(3) == lastsize

        # Primitive transforms are optional; when given, dim 2 must be 3.
        primtransfin = (primpos, primrot, primscale)
        for primtensor in primtransfin:
            assert primtensor is None or (
                primtensor.is_contiguous() and primtensor.size(2) == 3)

        # Channel axis is the last one for channels-last data, else axis 2.
        chaxis = -1 if chlast else 2
        assert (template.is_contiguous()
                and len(template.size()) == 6
                and template.size(chaxis) == 4)
        assert warp is None or (
            warp.is_contiguous() and warp.size(chaxis) == 3)

        # Acceleration structure; stays None unless usebvh enables it.
        # Note that only the literal False disables it (identity check).
        sortedobjid = nodechildren = nodeaabb = None
        if usebvh is not False:
            sortedobjid, nodechildren, nodeaabb = build_accel(
                primtransfin, algo, fixedorder=(usebvh == "fixedorder"))
            for acceltensor in (sortedobjid, nodechildren, nodeaabb):
                assert acceltensor.is_contiguous()

            if randomorder:
                # Shuffle entries along the first dimension.
                shuffle = torch.randperm(len(sortedobjid))
                sortedobjid = sortedobjid[shuffle]

        # Output buffers, sized from the leading three dims of raypos.
        nbatch, height, width = (raypos.size(axis) for axis in range(3))
        rayrgba = torch.empty((nbatch, height, width, 4), device=raypos.device)
        raysat = (
            torch.full((nbatch, height, width, 3), -1,
                       dtype=torch.float32, device=raypos.device)
            if gradmode else None)
        # The incoming rayterm is never forwarded to the kernel.
        rayterm = None

        # March through boxes; the kernel receives rayrgba (and raysat, if
        # allocated) as buffers.
        mvpraymarchlib.raymarch_forward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)

        # Keep everything a later backward pass may need.
        self.save_for_backward(
            raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm)
        self.options = options
        self.stepsize = stepsize

        return rayrgba