# gradcheck()
#
# Source: extensions/utils/utils.py [0:0]


def _gradcheck_zero_grads(params):
    """Detach and zero the ``.grad`` of every tensor in ``params`` that has one."""
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()


def _gradcheck_print_row(label, ref, test):
    """Print one comparison row between a reference tensor and a tested tensor.

    Columns: max absolute difference, cosine similarity of the two tensors
    (treated as flat vectors), flat index of the largest absolute difference,
    and the reference / tested values at that index.
    """
    absdiff = torch.abs(ref - test)
    ind = torch.argmax(absdiff)
    cos = torch.sum(ref * test) / torch.sqrt(torch.sum(ref * ref) * torch.sum(test * test))
    print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        label,
        torch.max(absdiff).item(),
        cos.item(),
        ind.item(),
        ref.reshape(-1)[ind].item(),
        test.reshape(-1)[ind].item()))


def gradcheck():
    """Compare a pure-PyTorch ray-direction computation against ``compute_raydirs``.

    Builds N=2 random cameras and a 64x64 pixel-coordinate grid on the CUDA
    device. It then computes normalized per-pixel ray directions in plain
    PyTorch as the reference. Next it calls
    ``compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]``;
    presumably element ``[1]`` is the ray-direction tensor, so confirm this
    against ``compute_raydirs``.

    It prints the following:

    * the forward-output comparison;
    * the forward/backward wall-clock time of the ``compute_raydirs`` call;
    * the gradient comparison for viewpos, viewrot, focal and princpt.

    Each comparison row shows the max abs diff, the cosine similarity, and the
    values at the worst index. Both versions are backpropagated with an
    all-ones upstream gradient.

    Requires a CUDA device. Returns None; results are only printed.
    """
    N = 2
    H = 64
    W = 64
    volradius = 1.
    niter = 1  # the timed call runs once; kept so reported times read as per-iteration

    # generate random inputs (fixed seed for reproducibility)
    torch.manual_seed(1113)

    rodrigues = Rodrigues()

    viewpos = torch.tensor([[-0.0, 0.0, -4.] for _ in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1
    viewrvec = torch.randn(N, 3, device="cuda") * 0.01
    viewrot = rodrigues(viewrvec)

    focal = torch.tensor([[W * 4.0, W * 4.0] for _ in range(N)], device="cuda")
    princpt = torch.tensor([[W * 0.5, H * 0.5] for _ in range(N)], device="cuda")
    pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float())
    pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

    # Turn every input into a fresh contiguous leaf tensor that accumulates gradients.
    viewpos = viewpos.contiguous().detach().clone().requires_grad_(True)
    viewrot = viewrot.contiguous().detach().clone().requires_grad_(True)
    focal = focal.contiguous().detach().clone().requires_grad_(True)
    princpt = princpt.contiguous().detach().clone().requires_grad_(True)
    pixelcoords = pixelcoords.contiguous().detach().clone().requires_grad_(True)

    # pixelcoords receives gradients too, but only these are compared.
    params = [viewpos, viewrot, focal, princpt]
    paramnames = ["viewpos", "viewrot", "focal", "princpt"]

    ########################### run pytorch version ###########################

    # Pinhole back-projection: ((pixel - principal point) / focal, 1).
    raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
    raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
    # Sums over the row index of viewrot, i.e. applies viewrot transposed to each direction.
    raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
    sample0 = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

    sample0.backward(torch.ones_like(sample0))

    grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]
    _gradcheck_zero_grads(params)

    ############################## run cuda version ###########################

    # Synchronize BEFORE reading the clock so earlier queued GPU work is not timed.
    torch.cuda.synchronize()
    tstart = time.time()

    sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]

    torch.cuda.synchronize()
    tfwd_end = time.time()

    sample1.backward(torch.ones_like(sample1), retain_graph=True)

    torch.cuda.synchronize()
    tbwd_end = time.time()

    tf = tfwd_end - tstart
    tb = tbwd_end - tfwd_end

    grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    ############# compare results #############

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
    _gradcheck_print_row("fwd", sample0, sample1)
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))

    for name, g0, g1 in zip(paramnames, grads0, grads1):
        if g0 is None or g1 is None:
            # A missing gradient cannot be diffed; report it instead of crashing.
            print("{:<10} missing gradient (py: {}, cuda: {})".format(
                name,
                "none" if g0 is None else "ok",
                "none" if g1 is None else "ok"))
            continue
        _gradcheck_print_row(name, g0, g1)