def jackknife2()

in jackknife2.py [0:0]


def _include_low_frequencies(mask, m, n, low):
    """
    marks the low-frequency corners of a k-space mask as retained

    Sets to True, in place, every entry of mask whose row index is among
    the first low or the last low of the m rows and whose column index is among
    the first low or the last low of the n columns (that is, the four
    low x low corner blocks of the m x n mask).

    Parameters
    ----------
    mask : array_like of bool
        m x n mask to modify in place (must support 2-D slice assignment)
    m : int
        number of rows in mask
    n : int
        number of columns in mask
    low : int
        bandwidth of low frequencies to include; nothing is changed when low
        is not positive, and a value exceeding a dimension of the mask
        includes that entire dimension
    """
    if low <= 0:
        # Nothing to include; also avoids slice bounds wrapping around.
        return
    # Clamp so that the trailing slices never start at a negative index.
    row_tail = max(m - low, 0)
    col_tail = max(n - low, 0)
    mask[:low, :low] = True
    mask[row_tail:, :low] = True
    mask[:low, col_tail:] = True
    mask[row_tail:, col_tail:] = True


def jackknife2(filein, fileout, angles, low, recon, recon_args, sdn=None):
    """
    plots twice the sum of the leave-one-out errors of reconstructions

    Plots and saves twice the sum of the leave-one-out differences between
    the reconstruction using all angles and the reconstructions that each omit
    a single angle, where every reconstruction is obtained via recon applied to
    the k-space subsampling of the image in filein specified by angles fed
    into radialines.randradialineset, while including all frequencies between
    -low to low in both directions, corrupting the k-space values with
    independent and identically distributed centered complex Gaussian noise
    whose standard deviation is sdn*sqrt(2) (sdn=0 if not provided explicitly).
    The "one" left out in the leave-one-out is a radial "line" at one of the
    angles specified by angles. The same noise realization is used for all
    reconstructions.

    Four plots are saved, to files named by inserting '_original', '_recon',
    '_error', and '_jackknife' between the stem and the extension of fileout.
    The function also sets matplotlib.rcParams entries that hide ticks, tick
    labels, and spines; these settings and the created figures persist after
    the function returns.

    The calling sequence of recon must be  (m, n, f, mask, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask is the return from calls to radialines (with angles),
    supplemented by all frequencies between -low to low in both directions, and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    angles : list of float
        angles of the radial "lines" in the mask that radialines will construct
        (any one-dimensional sequence supporting len, such as a numpy array,
        works, too)
    low : int
        bandwidth of low frequencies to include in the mask
        (between -low to low in both the horizontal and vertical directions)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)

    Returns
    -------
    float
        loss for the reconstruction using all angles
    list of float
        losses for the reconstructions using all angles except for one
    """
    # Imported locally so that this function does not rely on a module-level
    # import that may be absent.
    import os.path

    # Set default parameters.
    if sdn is None:
        sdn = 0
    # Read the image from disk.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Fourier transform the image.
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise.
    ff_noisy = ff_orig.copy()
    ff_noisy += sdn * (np.random.randn(m, n) + 1j * np.random.randn(m, n))
    # Select which frequencies to retain.
    mask = radialines.randradialineset(m, n, angles)
    # Include all low frequencies.
    _include_low_frequencies(mask, m, n, low)
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    # Use len rather than .size so that plain lists (as documented) work.
    num_angles = len(angles)
    logging.info(
        'computing jackknife2 differences -- all {}'.format(num_angles))
    # Perform the reconstruction using the entire mask.
    reconf, lossf = recon(m, n, f, mask, **recon_args)
    reconf = reconf.cpu().numpy()
    # Perform the reconstruction omitting different samples in k-space.
    recons = np.empty((num_angles, m, n))
    loss = []
    for k in range(num_angles):
        # Drop an angle.
        langles = list(angles)
        del langles[k]
        mask1 = radialines.randradialineset(m, n, langles)
        # Include all low frequencies.
        _include_low_frequencies(mask1, m, n, low)
        # Subsample the noisy Fourier transform of the original image.
        f1 = ctorch.from_numpy(ff_noisy[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        recon1, loss1 = recon(m, n, f1, mask1, **recon_args)
        # Record the results.
        recons[k, :, :] = recon1.cpu().numpy()
        loss.append(loss1)
    # Calculate the sum of the leave-one-out differences.
    sumloo = np.sum(recons - reconf, axis=0)
    scaled = sumloo * 2

    # Plot errors.
    # Remove the ticks and spines on the axes.
    matplotlib.rcParams.update({
        'xtick.top': False,
        'xtick.bottom': False,
        'ytick.left': False,
        'ytick.right': False,
        'xtick.labeltop': False,
        'xtick.labelbottom': False,
        'ytick.labelleft': False,
        'ytick.labelright': False,
        'axes.spines.top': False,
        'axes.spines.bottom': False,
        'axes.spines.left': False,
        'axes.spines.right': False,
    })
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the suffix (filetype) from the rest of the filename; splitext
    # ignores dots in directory names and tolerates a missing extension.
    rest, suffix = os.path.splitext(fileout)
    # Each entry gives the title, the image, its clipping range, the keyword
    # arguments for imshow, and the tag inserted into the name of the file.
    plots = [
        # The original.
        ('Original', f_orig, (0, 1), kwargs01, '_original'),
        # The reconstruction from the original mask provided.
        ('Reconstruction', reconf, (0, 1), kwargs01, '_recon'),
        # The difference from the original.
        ('Error of Reconstruction', reconf - f_orig, (-1, 1), kwargs11,
         '_error'),
        # Twice the sum of the leave-one-out differences.
        ('Jackknife', scaled, (-1, 1), kwargs11, '_jackknife'),
    ]
    # The figures are left open after saving, so callers can still show them.
    for title, image, (lower, upper), kwargs, tag in plots:
        plt.figure(figsize=(5.5, 5.5))
        plt.title(title)
        plt.imshow(np.clip(image, lower, upper), **kwargs)
        plt.savefig(rest + tag + suffix, bbox_inches='tight')

    return lossf, loss