def jackknife()

in jackknife.py [0:0]


def jackknife(filein, fileout, mask, low, recon, recon_args, sdn=None):
    """
    plots twice the sum of the leave-one-out errors of reconstructions

    Plots and saves twice the sum of the leave-one-out differences between
    the original image in filein and the reconstruction via recon applied to
    the k-space subsampling specified by mask (well, assuming mask includes
    all frequencies between -low and low), corrupting the k-space values with
    independent and identically distributed centered complex Gaussian noise
    whose standard deviation is sdn*sqrt(2) (sdn=0 if not provided explicitly).
    The "one" left out in the leave-one-out is a full row of k-space; the rows
    left out are those k with mask[k] True and low < k < m - low.

    The calling sequence of recon must be  (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    Four images are saved, named after fileout with a tag inserted before its
    extension: ``_original``, ``_recon``, ``_error``, and ``_jackknife``
    (e.g., ``out.png`` yields ``out_original.png``, ...).

    _N.B._: mask[-low+1], mask[-low+2], ..., mask[low-1] must be True.

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    mask : ndarray of bool
        indicators of whether to include (True) or exclude (False)
        the corresponding rows in k-space of the image from filein
    low : int
        bandwidth of low frequencies included in mask (from -low to low)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)

    Returns
    -------
    float
        loss for the reconstruction using all mask
    list of float
        losses for the reconstructions using all mask except for one row

    Raises
    ------
    AssertionError
        if mask omits any of the required low frequencies
    """
    import os

    # Set default parameters.
    if sdn is None:
        sdn = 0
    # Check that the mask includes all low frequencies, i.e., that
    # mask[0], mask[1], ..., mask[low-1] and mask[-1], ..., mask[-low+1] hold.
    for k in range(low):
        assert mask[k], 'mask[{}] must be True'.format(k)
        assert mask[-k], 'mask[{}] must be True'.format(-k)
    # Read the image from disk.
    # NOTE(review): assumes a single-channel (grayscale) image -- TODO confirm;
    # an RGB image would make fft2 act on the wrong axes.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m, n = f_orig.shape[:2]
    # Fourier transform the image; dividing by sqrt(m * n) makes it unitary.
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise: real and imaginary parts each have standard deviation sdn.
    ff_noisy = ff_orig + sdn * (
        np.random.randn(m, n) + 1j * np.random.randn(m, n))
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    # Index the True values in mask (aside from the low frequencies);
    # the inequalities are strict, so row 0 is excluded even when low = 0.
    # NOTE(review): this also excludes rows low and m - low, which the check
    # above does not require to be in mask -- confirm this is intended.
    trues = [k for k in range(mask.size) if low < k < m - low and mask[k]]
    logging.info('computing jackknife differences -- all %d', len(trues))
    # Perform the reconstruction using the entire mask.
    mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda()
    reconf, lossf = recon(m, n, f, mask_th, **recon_args)
    reconf = reconf.cpu().numpy()
    # Perform the reconstruction omitting different rows of k-space,
    # accumulating the sum of differences instead of storing every
    # reconstruction (keeps memory at O(m * n) rather than O(len(trues) * m * n)).
    sumloo = np.zeros((m, n))
    loss = []
    for row in trues:
        # Drop a row.
        mask1 = mask.copy()
        mask1[row] = False
        f1 = ctorch.from_numpy(ff_noisy[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        mask1_th = torch.from_numpy(mask1.astype(np.uint8)).cuda()
        recon1, loss1 = recon(m, n, f1, mask1_th, **recon_args)
        # Promote to float64, matching the precision of the accumulator.
        sumloo += recon1.cpu().numpy().astype(np.float64) - reconf
        loss.append(loss1)
    scaled = sumloo * 2

    # Plot errors, removing the ticks, tick labels, and spines on the axes.
    # rc_context scopes these settings to this function rather than
    # permanently altering the global matplotlib.rcParams.
    bare_axes = {
        'xtick.top': False,
        'xtick.bottom': False,
        'ytick.left': False,
        'ytick.right': False,
        'xtick.labeltop': False,
        'xtick.labelbottom': False,
        'ytick.labelleft': False,
        'ytick.labelright': False,
        'axes.spines.top': False,
        'axes.spines.bottom': False,
        'axes.spines.left': False,
        'axes.spines.right': False,
    }
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the suffix (filetype) from the rest of the filename; splitext
    # ignores dots in directory names and tolerates a missing extension.
    rest, suffix = os.path.splitext(fileout)

    def save_plot(tag, title, image, kwargs):
        # Save one titled image to rest + tag + suffix, then close the figure
        # so repeated calls do not accumulate open pyplot figures.
        fig = plt.figure(figsize=(5.5, 5.5))
        try:
            plt.title(title)
            plt.imshow(image, **kwargs)
            plt.savefig(rest + tag + suffix, bbox_inches='tight')
        finally:
            plt.close(fig)

    with matplotlib.rc_context(bare_axes):
        # Plot the original.
        save_plot('_original', 'Original', np.clip(f_orig, 0, 1), kwargs01)
        # Plot the reconstruction from the original mask provided.
        save_plot('_recon', 'Reconstruction', np.clip(reconf, 0, 1),
                  kwargs01)
        # Plot the difference from the original.
        save_plot('_error', 'Error of Reconstruction',
                  np.clip(reconf - f_orig, -1, 1), kwargs11)
        # Plot twice the sum of the leave-one-out differences.
        save_plot('_jackknife', 'Jackknife', np.clip(scaled, -1, 1),
                  kwargs11)

    return lossf, loss