in extensions/mvpraymarch/mvpraymarch.py [0:0]
def forward(self, raypos, raydir, stepsize, tminmax,
        primpos, primrot, primscale,
        template, warp,
        rayterm, gradmode, options):
    """Run the raymarching forward pass through ``mvpraymarchlib``.

    Validates the inputs, optionally builds an acceleration structure over
    the primitives with ``build_accel``, allocates output buffers, calls
    ``mvpraymarchlib.raymarch_forward`` and stores what a backward pass
    needs on ``self``.

    ``self`` is used like an autograd context: ``save_for_backward`` is
    called on it and ``options`` / ``stepsize`` are stored as attributes.
    NOTE(review): presumably this is ``torch.autograd.Function.forward``
    with ``self`` acting as ``ctx`` -- confirm against the enclosing class.

    Args:
        raypos, raydir: contiguous tensors whose ``size(3)`` is 3; dims
            0-2 of ``raypos`` are read as (N, H, W).
        stepsize: raymarch step size, forwarded to the extension and kept
            on ``self.stepsize``.
        tminmax: contiguous tensor whose ``size(3)`` is 2.
        primpos, primrot, primscale: each ``None`` or a contiguous tensor
            with ``size(2) == 3``.
        template: contiguous 6-D tensor with 4 channels, located in the
            last dim when ``options["chlast"]`` is truthy, else in dim 2.
        warp: ``None`` or a contiguous tensor with 3 channels in the same
            channel position as ``template``.
        rayterm: accepted for interface compatibility but not used; the
            extension and ``save_for_backward`` always receive ``None``.
        gradmode: when truthy, a ``raysat`` buffer of shape (N, H, W, 3)
            filled with -1 is allocated and passed to the extension;
            otherwise ``raysat`` is ``None``.
        options: dict read for the keys ``algo``, ``usebvh``,
            ``sortprims``, ``randomorder``, ``maxhitboxes``,
            ``synchitboxes``, ``chlast``, ``fadescale``, ``fadeexp``,
            ``accum``, ``termthresh``, ``griddim`` and ``blocksize``.
            ``blocksize`` is either a single value (y block size 1) or an
            (x, y) pair given as a tuple or list. Unless ``usebvh`` is
            ``False`` an acceleration structure is built, in fixed order
            when ``usebvh == "fixedorder"``.

    Returns:
        ``rayrgba``: tensor of shape (N, H, W, 4) on ``raypos.device``
        (default dtype). It is allocated uninitialised here and presumably
        filled in place by ``raymarch_forward``.

    Raises:
        AssertionError: if an input fails the contiguity/shape checks.
    """
    algo = options["algo"]
    usebvh = options["usebvh"]
    sortprims = options["sortprims"]
    randomorder = options["randomorder"]
    maxhitboxes = options["maxhitboxes"]
    synchitboxes = options["synchitboxes"]
    chlast = options["chlast"]
    fadescale = options["fadescale"]
    fadeexp = options["fadeexp"]
    accum = options["accum"]
    termthresh = options["termthresh"]
    griddim = options["griddim"]

    blocksize = options["blocksize"]
    if isinstance(blocksize, (tuple, list)):
        # Lists are accepted as well as tuples: configs loaded from
        # JSON/YAML cannot express tuples, and a list used to fall through
        # to the scalar branch and be passed whole to the extension.
        blocksizex, blocksizey = blocksize
    else:
        blocksizex, blocksizey = blocksize, 1

    # Input validation: the extension is handed raw tensors, so layout and
    # channel counts are checked here.
    assert raypos.is_contiguous() and raypos.size(3) == 3, \
        "raypos must be contiguous with size(3) == 3"
    assert raydir.is_contiguous() and raydir.size(3) == 3, \
        "raydir must be contiguous with size(3) == 3"
    assert tminmax.is_contiguous() and tminmax.size(3) == 2, \
        "tminmax must be contiguous with size(3) == 2"
    for name, prim in (("primpos", primpos),
                       ("primrot", primrot),
                       ("primscale", primscale)):
        assert prim is None or (prim.is_contiguous() and prim.size(2) == 3), \
            f"{name} must be None or contiguous with size(2) == 3"

    # Channel dimension of template/warp depends on the memory layout.
    chdim = -1 if chlast else 2
    assert (template.is_contiguous() and template.dim() == 6
            and template.size(chdim) == 4), \
        f"template must be a contiguous 6-D tensor with size({chdim}) == 4"
    assert warp is None or (warp.is_contiguous() and warp.size(chdim) == 3), \
        f"warp must be None or contiguous with size({chdim}) == 3"

    primtransfin = (primpos, primrot, primscale)

    # Build the acceleration structure unless explicitly disabled.
    if usebvh is not False:
        sortedobjid, nodechildren, nodeaabb = build_accel(
            primtransfin, algo, fixedorder=usebvh == "fixedorder")
        assert sortedobjid.is_contiguous()
        assert nodechildren.is_contiguous()
        assert nodeaabb.is_contiguous()
        if randomorder:
            # NOTE(review): this permutes dim 0 of sortedobjid; if it is
            # batched as (N, K) this shuffles whole rows rather than the
            # primitive order within a row -- confirm intent.
            sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
    else:
        # No acceleration structure: the extension receives None for all.
        sortedobjid = nodechildren = nodeaabb = None

    # Allocate outputs for the march through boxes.
    N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
    rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
    if gradmode:
        raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32,
                            device=raypos.device)
    else:
        raysat = None
    # The caller's rayterm is discarded in both modes; None is always passed
    # to the extension and saved for backward.
    # NOTE(review): confirm this is intentional -- the parameter is unused.
    rayterm = None

    mvpraymarchlib.raymarch_forward(
        raypos, raydir, stepsize, tminmax,
        sortedobjid, nodechildren, nodeaabb,
        *primtransfin,
        template, warp,
        rayrgba, raysat, rayterm,
        algo, sortprims, maxhitboxes, synchitboxes, chlast,
        fadescale, fadeexp,
        accum, termthresh,
        griddim, blocksizex, blocksizey)

    self.save_for_backward(
        raypos, raydir, tminmax,
        sortedobjid, nodechildren, nodeaabb,
        primpos, primrot, primscale,
        template, warp,
        rayrgba, raysat, rayterm)
    self.options = options
    self.stepsize = stepsize
    return rayrgba