def _fft()

in ctorch.py [0:0]


def _fft(x, rank, direction, plan_cache=None):
    """Run a batched complex-to-complex (Z2Z) cuFFT over the trailing dims of *x*.

    Every dimension of ``x`` before the last ``rank`` is flattened into a
    single batch dimension, and the last ``rank`` dimensions are transformed.
    The result is reshaped back to ``x``'s original shape.

    Args:
        x: ``ComplexTensor`` whose ``real`` part is a
            ``torch.cuda.DoubleTensor`` with at least ``rank`` dimensions.
        rank: number of trailing dimensions to transform.
        direction: cuFFT direction flag forwarded to ``cufftExecZ2Z``. When
            it equals ``CUFFT_INVERSE``, the output is divided by the number
            of elements in the transformed dimensions (cuFFT itself does not
            normalize).
        plan_cache: unused; kept only so existing call sites keep working.
            It previously defaulted to a shared mutable ``{}``.

    Returns:
        A ``ComplexTensor`` with the same shape as ``x``.

    Raises:
        AssertionError: if ``x`` fails the type/rank checks (only when
            assertions are enabled).
        RuntimeError: if ``cufftExecZ2Z`` does not return ``CUFFT_SUCCESS``.
    """
    assert isinstance(x, ComplexTensor)
    assert isinstance(x.real, torch.cuda.DoubleTensor)
    assert x.real.dim() >= rank
    orig_shape = x.shape
    # Normalize to (batch, *transform_dims); a single transform gets batch size 1.
    if x.real.dim() == rank:
        x = x.unsqueeze(0)
    else:
        x = x.view(-1, *x.shape[-rank:])
    # NOTE(review): _plan is defined elsewhere; presumably it builds (or caches)
    # a batched Z2Z plan for this (batch, *dims) shape -- confirm.
    plan = _plan(*x.shape)
    # Stack real/imag on a new last axis -> (batch, *dims, 2), i.e. interleaved
    # (re, im) pairs, which is the memory layout of cufftDoubleComplex.
    # torch.stack returns a fresh contiguous tensor, so data_ptr() is safe to use.
    x_stack = torch.stack((x.real, x.imag), dim=rank + 1)
    # Uninitialized output buffer of the same shape; filled by the out-of-place
    # transform below.
    y_stack = x_stack.new(x_stack.shape)
    # The call must NOT live inside an `assert`: `python -O` removes assert
    # statements entirely, which would skip the transform and return garbage.
    status = cufft.cufftExecZ2Z(
        plan, ctypes.c_void_p(x_stack.data_ptr()),
        ctypes.c_void_p(y_stack.data_ptr()), direction)
    if status != CUFFT_SUCCESS:
        raise RuntimeError(
            "cufftExecZ2Z failed with status {}".format(status))
    # Wait for the GPU work to finish before reading the result back into torch.
    torch.cuda.synchronize()
    y_real, y_imag = y_stack.split(1, dim=rank + 1)
    y = ComplexTensor(y_real, y_imag).contiguous().view(orig_shape)
    if direction == CUFFT_INVERSE:
        # Scale by 1/N (N = number of transformed elements) so that a forward
        # transform followed by an inverse transform round-trips.
        y = y / float(np.prod(x.shape[-rank:]))
    return y