in extensions/mvpraymarch/mvpraymarch.py [0:0]
def _raymarch_reference(raypos, raydir, stepsize, tminmax, template, warp,
                        primpos, primrot, primscale, fadescale, fadeexp,
                        dowarp, accum):
    """Ray march the primitive volumes with plain PyTorch ops (reference path).

    Each ray starts at ``raypos + raydir * tminmax[..., 0]`` and advances by
    ``stepsize`` until every ray has passed ``tminmax[..., 1]``. At every step,
    each primitive ``k`` is visited in order: the sample point is mapped to
    local coordinates ``y0 = (p - primpos[:, k]) @ primrot[:, k] * primscale[:, k]``,
    optionally remapped through ``warp[:, k]`` (when ``dowarp``), and used to
    ``grid_sample`` ``template[:, k]``. Sample channels 0:3 are used as RGB and
    channel 3 as alpha. Points outside ``[-1, 1]^3`` in local coordinates, or
    outside the ray's [tmin, tmax) interval, contribute nothing. Alpha is
    accumulated additively and clamped so it never exceeds 1.

    Returns:
        Tensor of shape (N, H, W, 4): accumulated RGB in channels 0:3 and
        alpha in channel 3. N, H, W are taken from ``raypos.shape[:3]``.

    Raises:
        NotImplementedError: if ``accum != 0``; only ``accum == 0`` is
        implemented here.
    """
    if accum != 0:
        raise NotImplementedError(
            "reference ray marcher only supports accum=0 (got accum={})".format(accum))

    N, H, W = raypos.shape[:3]
    K = primpos.size(1)
    rayrgba = torch.zeros((N, H, W, 4), device=raypos.device)

    tmin = tminmax[:, :, :, 0]
    tmax = tminmax[:, :, :, 1]
    # Positions/times are recomputed from these starting values every step
    # (instead of being accumulated) to avoid drift.
    t0 = tmin.detach().clone()
    raypos0 = (raypos + raydir * tmin[:, :, :, None]).detach().clone()
    t = t0
    raypos = raypos0
    step = 0
    while (t < tmax).any():
        # The ray-interval mask depends only on t, so compute it once per step.
        valid = ((t >= tmin) & (t < tmax)).float()[:, :, :, None]
        for k in range(K):
            # Sample point in primitive k's local frame.
            y0 = torch.bmm(
                (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
                primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]
            # Attenuation that decays with |y0|, shaped by fadescale/fadeexp.
            fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))
            if dowarp:
                y1 = F.grid_sample(
                    warp[:, k, :, :, :, :],
                    y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            else:
                y1 = y0
            sample = F.grid_sample(
                template[:, k, :, :, :, :],
                y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            # 1 where the local point lies inside [-1, 1]^3, else 0.
            valid1 = (
                torch.prod(y0 >= -1., dim=-1, keepdim=True) *
                torch.prod(y0 <= 1., dim=-1, keepdim=True))
            mask = valid * valid1
            rgb = sample[:, :, :, 0:3] * mask
            alpha = sample[:, :, :, 3:4] * fade * stepsize * mask
            # Additive alpha, clamped so the running total never exceeds 1.
            newalpha = rayrgba[:, :, :, 3:4] + alpha
            contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * mask
            rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
        step += 1
        t = t0 + stepsize * step
        raypos = raypos0 + raydir * stepsize * step
    return rayrgba


def _print_diff_row(label, ref, test):
    """Print one row of the reference-vs-CUDA comparison table.

    Columns: max |ref - test|, cosine similarity of the flattened tensors,
    L2 norm of each, flat index of the largest difference, and the two values
    at that index. The cosine similarity is NaN when either tensor is all
    zeros (0 / 0).
    """
    with torch.no_grad():
        absdiff = torch.abs(ref - test)
        ind = torch.argmax(absdiff)
        print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            label,
            torch.max(absdiff).item(),
            (torch.sum(ref * test) / torch.sqrt(torch.sum(ref * ref) * torch.sum(test * test))).item(),
            torch.sqrt(torch.sum(ref * ref)).item(),
            torch.sqrt(torch.sum(test * test)).item(),
            ind.item(),
            ref.reshape(-1)[ind].item(),
            test.reshape(-1)[ind].item()))


def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
              dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
              accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
    """Compare ``mvpraymarch`` against a pure-PyTorch reference on random data.

    Builds a small seeded random scene: N=2 images of 65x65 rays and K=64
    primitives, each with a 4-channel 32^3 template and a 3-channel 16^3 warp.
    It renders the scene with ``_raymarch_reference`` and with
    ``mvpraymarch``, back-propagates a gradient of ones through both, and
    prints timings. It then prints a table comparing the forward output and
    the gradient of every parameter. Requires a CUDA device.

    The keyword arguments are forwarded to ``mvpraymarch``. ``fadescale``,
    ``fadeexp``, ``dowarp`` and ``accum`` also drive the reference.
    ``algo == 2`` feeds ``template * 1.5`` instead of its softplus to both
    paths. ``chlast`` permutes template/warp to channels-last for the CUDA
    call only. The warp gradient is compared only when ``dowarp`` is true.

    NOTE(review): ``termthresh`` is only passed to ``mvpraymarch``; the
    reference does not use it, so nonzero values may show up as differences.

    Raises:
        NotImplementedError: if ``accum != 0`` (the reference only implements
        ``accum == 0``).
    """
    N = 2
    H = 65
    W = 65
    k3 = 4
    K = k3 * k3 * k3
    M = 32
    print("=================================================================")
    print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
          "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
          "blocksize={}, bwdblocksize={}".format(
              usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
              fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
              bwdblocksize))

    # ------------------------- generate random inputs -------------------------
    torch.manual_seed(1112)
    coherent_rays = True
    if not coherent_rays:
        _raypos = torch.randn(N, H, W, 3).to("cuda")
        _raydir = torch.randn(N, H, W, 3).to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    else:
        # Pinhole camera at (0, 0, -4) looking down +z.
        focal = torch.tensor([[W * 4.0, W * 4.0] for _ in range(N)])
        princpt = torch.tensor([[W * 0.5, H * 0.5] for _ in range(N)])
        pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
        pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)
        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
        _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
        _raydir = raydir.to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    max_len = 6.0
    _stepsize = max_len / 15.386928
    # tmin in [0, 1), tmax in [max_len, max_len + 1).
    _tminmax = max_len * torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
        torch.rand(N, H, W, 2, device="cuda") * 1.

    _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
    _template.data[:, :, -1, :, :, :] -= 3.5
    _template = _template.contiguous().detach().clone()
    _template.requires_grad = True

    # Warp = identity sampling grid plus small noise.
    gridxyz = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., M // 2),
        torch.linspace(-1., 1., M // 2),
        torch.linspace(-1., 1., M // 2))[::-1], dim=0).contiguous()
    _warp = (torch.randn(N, K, 3, M // 2, M // 2, M // 2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
    _warp.requires_grad = True

    _primpos = torch.randn(N, K, 3, requires_grad=True)
    coherent_centers = True
    if coherent_centers:
        # Place primitives on a regular k3^3 lattice in [-1, 1]^3 with jitter.
        ns = k3
        grid3d = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., K // (ns * ns)))[::-1], dim=0)[None]
        _primpos = ((
            grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
            0.1 * torch.randn(N, K, 3, requires_grad=True)
        )).contiguous().detach().clone()
        _primpos.requires_grad = True
    scale_ws = 1.
    _primrot = torch.randn(N, K, 3)
    rodrigues = Rodrigues()
    _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
    _primrot.requires_grad = True
    _primscale = torch.randn(N, K, 3, requires_grad=True)
    _primscale.data *= 0.0

    if dowarp:
        params = [_template, _warp, _primscale, _primrot, _primpos]
        paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
    else:
        params = [_template, _primscale, _primrot, _primpos]
        paramnames = ["template", "primscale", "primrot", "primpos"]

    ########################### run pytorch version ###########################
    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
    torch.cuda.synchronize()
    time0 = time.time()
    sample0 = _raymarch_reference(
        _raypos, _raydir, _stepsize, _tminmax, template, warp,
        primpos, primrot, primscale, fadescale, fadeexp, dowarp, accum)
    torch.cuda.synchronize()
    time1 = time.time()
    print(sample0[..., -1].min().item(), sample0[..., -1].max().item())
    sample0.backward(torch.ones_like(sample0))
    torch.cuda.synchronize()
    time2 = time.time()
    print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))
    grads0 = [p.grad.detach().clone() for p in params]
    for p in params:
        p.grad.detach_()
        p.grad.zero_()

    ############################## run cuda version ###########################
    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    if chlast:
        template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
        warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
    niter = 1
    tf, tb = 0., 0.
    for _ in range(niter):
        for p in params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
        # Synchronize BEFORE reading the clock so each interval covers the
        # finished GPU work of exactly one phase.
        torch.cuda.synchronize()
        tstart = time.time()
        sample1 = mvpraymarch(_raypos, _raydir, _stepsize, _tminmax,
                              (primpos, primrot, primscale),
                              template, warp if dowarp else None,
                              algo=algo, usebvh=usebvh, sortprims=sortprims,
                              maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
                              chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
                              accum=accum, termthresh=termthresh,
                              griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)
        torch.cuda.synchronize()
        tfwd = time.time()
        sample1.backward(torch.ones_like(sample1), retain_graph=True)
        torch.cuda.synchronize()
        tbwd = time.time()
        tf += tfwd - tstart
        tb += tbwd - tfwd
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() for p in params]

    ############# compare results #############
    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
    _print_diff_row("fwd", sample0, sample1)
    for name, g0, g1 in zip(paramnames, grads0, grads1):
        _print_diff_row(name, g0, g1)