def gradcheck()

in extensions/mvpraymarch/mvpraymarch.py [0:0]
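
gradcheck() drives a pure-PyTorch reference raymarcher and the CUDA mvpraymarch kernel over the same random inputs, times both, and reports how closely the forward outputs and parameter gradients agree. A minimal sketch of the imports the function relies on; Rodrigues (axis-angle to rotation matrices) and the mvpraymarch entry point are assumed to be defined earlier in this module or elsewhere in this extension:

import time

import torch
import torch.nn.functional as F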


def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
        dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
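    # problem size: N images of H x W rays, K = k3**3 volumetric primitives,
    # each holding an M**3 voxel payload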
    N = 2
    H = 65
    W = 65
    k3 = 4
    K = k3*k3*k3

    M = 32

    print("=================================================================")
    print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
        "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
        "blocksize={}, bwdblocksize={}".format(
        usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
        fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
        bwdblocksize))

    # generate random inputs
    torch.manual_seed(1112)

    coherent_rays = True
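    # rays: either fully random origins/directions, or a coherent pinhole-camera
    # bundle (focal length, principal point) looking down +z from z = -4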
    if not coherent_rays:
        _raypos = torch.randn(N, H, W, 3).to("cuda")
        _raydir = torch.randn(N, H, W, 3).to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    else:
        focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)])
        princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)])
        pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
        pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

        _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
        _raydir = raydir.to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))

    max_len = 6.0
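    # fixed marching step size; tminmax gives each ray [tmin, tmax] integration
    # bounds, jittered by up to one unit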
    _stepsize = max_len / 15.386928
    _tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
            torch.rand(N, H, W, 2, device="cuda") * 1.

    _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
    _template.data[:, :, -1, :, :, :] -= 3.5
    _template = _template.contiguous().detach().clone()
    _template.requires_grad = True
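    # warp field: a per-primitive 3D resampling grid at half the template
    # resolution, initialized near the identity mapping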
    gridxyz = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2))[::-1], dim=0).contiguous()
    _warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
    _warp.requires_grad = True
    _primpos = torch.randn(N, K, 3, requires_grad=True)

    coherent_centers = True
    if coherent_centers:
        ns = k3
        #assert ns*ns*ns==K
        grid3d = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., K//(ns*ns)))[::-1], dim=0)[None]
        _primpos = ((
            grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
            0.1 * torch.randn(N, K, 3, requires_grad=True)
            )).contiguous().detach().clone()
        _primpos.requires_grad = True
    scale_ws = 1.
    _primrot = torch.randn(N, K, 3)
    rodrigues = Rodrigues()
    _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
    _primrot.requires_grad = True

    _primscale = torch.randn(N, K, 3, requires_grad=True)
    _primscale.data *= 0.0
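
    # parameters whose gradients are compared between the python and CUDA paths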

    if dowarp:
        params = [_template, _warp, _primscale, _primrot, _primpos]
        paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
    else:
        params = [_template, _primscale, _primrot, _primpos]
        paramnames = ["template", "primscale", "primrot", "primpos"]

    termthreshorig = termthresh

    ########################### run pytorch version ###########################

    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    #template = F.softplus(_template.to("cuda") * 1.5)
    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    # python raymarching implementation
    rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
    raypos = raypos + raydir * tminmax[:, :, :, 0, None]
    t = tminmax[:, :, :, 0]

    step = 0
    t0 = t.detach().clone()
    raypos0 = raypos.detach().clone()

    torch.cuda.synchronize()
    time0 = time.time()
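
    # march all rays in lockstep: at each step, transform the sample point into each
    # primitive's local frame, sample its (optionally warped) RGBA template, window it
    # with the fade term, and alpha-composite until every ray passes its tmax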

    while (t < tminmax[:, :, :, 1]).any():

        for k in range(K):
            y0 = torch.bmm(
                    (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
                    primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]
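
            # fade: soft box window exp(-fadescale * sum_i |y0_i|**fadeexp), close to 1
            # near the primitive center and falling toward 0 at its faces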

            fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))

            if dowarp:
                y1 = F.grid_sample(
                        warp[:, k, :, :, :, :],
                        y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            else:
                y1 = y0

            sample = F.grid_sample(
                    template[:, k, :, :, :, :],
                    y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)

            valid1 = (
                torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
                torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))

            valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]

            alpha0 = sample[:, :, :, 3:4]

            rgb = sample[:, :, :, 0:3] * valid * valid1
            alpha = alpha0 * fade * stepsize * valid * valid1
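
            # accum == 0: clamped additive compositing; a sample contributes the increase
            # in accumulated alpha, which saturates once the ray's alpha reaches 1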

            if accum == 0:
                newalpha = rayrgba[:, :, :, 3:4] + alpha
                contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
                rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
            else:
                raise NotImplementedError("accum != 0 is not implemented in this python reference")

        step += 1
        t = t0 + stepsize * step
        raypos = raypos0 + raydir * stepsize * step

    print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())

    sample0 = rayrgba

    torch.cuda.synchronize()
    time1 = time.time()

    sample0.backward(torch.ones_like(sample0))

    torch.cuda.synchronize()
    time2 = time.time()

    print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))

    grads0 = [p.grad.detach().clone() for p in params]

    for p in params:
        p.grad.detach_()
        p.grad.zero_()

    ############################## run cuda version ###########################

    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    if chlast:
        template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
        warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    niter = 1
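
    # time the CUDA raymarcher over niter forward/backward passes, zeroing grads between runs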

    tf, tb = 0., 0.
    for i in range(niter):
        for p in params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
        torch.cuda.synchronize()
        t0 = time.time()
        sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
                (primpos, primrot, primscale),
                template, warp if dowarp else None,
                algo=algo, usebvh=usebvh, sortprims=sortprims, 
                maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
                chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
                accum=accum, termthresh=termthreshorig,
                griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)
        torch.cuda.synchronize()
        t1 = time.time()
        sample1.backward(torch.ones_like(sample1), retain_graph=True)
        torch.cuda.synchronize()
        t2 = time.time()
        tf += t1 - t0
        tb += t2 - t1

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() for p in params]

    ############# compare results #############
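
    # for the forward output and each parameter gradient, report the max absolute
    # difference, the normalized dot product (dp), both norms, and the values at the
    # element with the largest discrepancy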

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        torch.sqrt(torch.sum(sample0 * sample0)).item(),
        torch.sqrt(torch.sum(sample1 * sample1)).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
                p,
                torch.max(torch.abs(g0 - g1)).item(),
                (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
                torch.sqrt(torch.sum(g0 * g0)).item(),
                torch.sqrt(torch.sum(g1 * g1)).item(),
                ind.item(),
                g0.view(-1)[ind].item(),
                g1.view(-1)[ind].item()))
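
A hedged sketch of how gradcheck might be invoked; the keyword choices below are illustrative, drawn only from the parameters in its signature:

if __name__ == "__main__":
    # exercise both code paths: default settings, then with the warp field enabled
    gradcheck()
    gradcheck(dowarp=True)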