def bootstrap()

in bootstrap.py [0:0]


def bootstrap(filein, fileout, subsampling_factor, mask, low, recon,
              recon_args, n_resamps=None, sdn=None, viz=None):
    """
    plots thrice the average of bootstrapped errors in reconstruction

    Plots and saves thrice the average of the differences between
    a reconstruction of the original image in filein and n_resamps bootstrap
    reconstructions via recon applied to the k-space subsamplings specified
    by mask and by other masks generated similarly (retaining each row of
    the image with probability given by subsampling_factor, then adding all
    frequencies between -low and low), corrupting the k-space values with
    independent and identically distributed centered complex Gaussian noise
    whose standard deviation is sdn*sqrt(2) (sdn=0 if not provided explicitly).
    Setting viz to be True yields colorized visualizations, too, including
    the error estimates overlaid over the reconstruction, the error estimates
    blurred overlaid over the reconstruction, the error estimates blurred,
    the error estimates subtracted from the reconstruction, the error estimates
    saturating the reconstruction in hue-saturation-value (HSV) color space,
    and the error estimates interpolating the reconstruction in HSV space.

    The calling sequence of recon must be  (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.unit8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    _N.B._: mask[-low+1], mask[-low+2], ..., mask[low-1] must be True.

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        probability of retaining a row in the subsampling masks
    mask : ndarray of bool
        indicators of whether to include (True) or exclude (False)
        the corresponding rows in k-space of the image from filein
    low : int
        bandwidth of low frequencies included in mask (between -low to low)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    n_resamps : int, optional
        number of bootstrap resampled reconstructions (defaults to 100)
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)
    viz : bool, optional
        indicator of whether to generate colorized visualizations
        (defaults to False)

    Returns
    -------
    float
        loss for the reconstruction using the original mask
    list of float
        losses for the reconstructions using other, randomly generated masks
    float
        square root of the sum of the square of the estimated errors
    float
        square root of the sum of the square of the estimated errors blurred
    """
    # Set default parameters.
    if n_resamps is None:
        n_resamps = 100
    if sdn is None:
        sdn = 0
    if viz is None:
        viz = False
    # Check that the mask includes all low frequencies.
    for k in range(low):
        assert mask[k]
        assert mask[-k]
    # Read the image from disk.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Fourier transform the image.
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise.
    ff_noisy = ff_orig.copy()
    ff_noisy += sdn * (np.random.randn(m, n) + 1j * np.random.randn(m, n))
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    logging.info('computing bootstrap resamplings -- all {}'.format(n_resamps))
    # Perform the reconstruction using the mask.
    mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda()
    reconf, lossf = recon(m, n, f, mask_th, **recon_args)
    reconf = reconf.cpu().numpy()
    # Fourier transform the reconstruction.
    freconf = np.fft.fft2(reconf) / np.sqrt(m * n)
    # Perform the reconstruction resampling new masks and samples in k-space.
    recons = np.ndarray((n_resamps, m, n))
    loss = []
    for k in range(n_resamps):
        # Select which frequencies to retain.
        maski = set(np.floor(
            m * np.random.uniform(size=round(m * subsampling_factor))))
        mask1 = np.asarray([False] * m, dtype=bool)
        for i in maski:
            mask1[int(i)] = True
        # Make the optimization well-posed by including the zero frequency.
        mask1[0] = True
        # Include all low frequencies.
        for j in range(low):
            mask1[j] = True
            mask1[-j] = True
        # Subsample the Fourier transform of the reconstruction.
        f1 = ctorch.from_numpy(freconf[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        mask1_th = torch.from_numpy(mask1.astype(np.uint8)).cuda()
        recon1, loss1 = recon(m, n, f1, mask1_th, **recon_args)
        recon1 = recon1.cpu().numpy()
        # Record the results.
        recons[k, :, :] = recon1
        loss.append(loss1)
    # Calculate the sum of the bootstrap differences.
    sumboo = np.sum(recons - reconf, axis=0)
    scaled = sumboo * 3 / n_resamps
    # Blur the error estimates.
    sigma = 1
    blurred = skimage.filters.gaussian(scaled, sigma=sigma)
    rsse_estimated = np.linalg.norm(scaled, ord='fro')
    rsse_blurred = np.linalg.norm(blurred, ord='fro')

    # Plot errors.
    # Remove the ticks and spines on the axes.
    matplotlib.rcParams['xtick.top'] = False
    matplotlib.rcParams['xtick.bottom'] = False
    matplotlib.rcParams['ytick.left'] = False
    matplotlib.rcParams['ytick.right'] = False
    matplotlib.rcParams['xtick.labeltop'] = False
    matplotlib.rcParams['xtick.labelbottom'] = False
    matplotlib.rcParams['ytick.labelleft'] = False
    matplotlib.rcParams['ytick.labelright'] = False
    matplotlib.rcParams['axes.spines.top'] = False
    matplotlib.rcParams['axes.spines.bottom'] = False
    matplotlib.rcParams['axes.spines.left'] = False
    matplotlib.rcParams['axes.spines.right'] = False
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the suffix (filetype) from the rest of the filename.
    suffix = '.' + fileout.split('.')[-1]
    rest = fileout[:-len(suffix)]
    assert fileout == rest + suffix
    # Plot the original.
    plt.figure(figsize=(5.5, 5.5))
    plt.title('Original')
    plt.imshow(np.clip(f_orig, 0, 1), **kwargs01)
    plt.savefig(rest + '_original' + suffix, bbox_inches='tight')
    # Plot the reconstruction from the original mask provided.
    plt.figure(figsize=(5.5, 5.5))
    plt.title('Reconstruction')
    plt.imshow(np.clip(reconf, 0, 1), **kwargs01)
    plt.savefig(rest + '_recon' + suffix, bbox_inches='tight')
    # Plot the difference from the original.
    plt.figure(figsize=(5.5, 5.5))
    plt.title('Error of Reconstruction')
    plt.imshow(np.clip(reconf - f_orig, -1, 1), **kwargs11)
    plt.savefig(rest + '_error' + suffix, bbox_inches='tight')
    # Plot thrice the average of the bootstrap differences.
    plt.figure(figsize=(5.5, 5.5))
    plt.title('Bootstrap')
    plt.imshow(np.clip(scaled, -1, 1), **kwargs11)
    plt.savefig(rest + '_bootstrap' + suffix, bbox_inches='tight')

    if viz:
        # Plot the reconstruction minus the bootstrap difference.
        plt.figure(figsize=(5.5, 5.5))
        plt.title('Reconstruction \u2013 Bootstrap')
        plt.imshow(np.clip(reconf - scaled, 0, 1), **kwargs01)
        plt.savefig(rest + '_corrected' + suffix, bbox_inches='tight')
        # Overlay the error estimates on the reconstruction.
        plt.figure(figsize=(5.5, 5.5))
        threshold = np.abs(scaled).flatten()
        threshold = np.sort(threshold)
        maxthresh = threshold[-1]
        threshold = threshold[round(0.98 * threshold.size)]
        hue = 2. / 3 + (scaled / maxthresh) / 4 * 2 / 3
        saturation = np.abs(scaled) > threshold
        value = reconf * (1 - saturation) + saturation
        hsv = np.dstack((hue, saturation, value))
        rgb = hsv_to_rgb(hsv)
        plt.title('Errors Over a Threshold Overlaid')
        plt.imshow(np.clip(rgb, 0, 1))
        plt.savefig(rest + '_overlaid' + suffix, bbox_inches='tight')
        # Overlay the blurred error estimates on the reconstruction.
        plt.figure(figsize=(5.5, 5.5))
        threshold = np.abs(blurred).flatten()
        threshold = np.sort(threshold)
        maxthresh = threshold[-1]
        threshold = threshold[round(0.98 * threshold.size)]
        hue = 2. / 3 + (blurred / maxthresh) / 4 * 2 / 3
        saturation = np.abs(blurred) > threshold
        value = reconf * (1 - saturation) + saturation
        hsv = np.dstack((hue, saturation, value))
        rgb = hsv_to_rgb(hsv)
        plt.title('Blurred Errors Over a Threshold Overlaid')
        plt.imshow(np.clip(rgb, 0, 1))
        plt.savefig(rest + '_blurred_overlaid' + suffix, bbox_inches='tight')
        # Plot a bootstrap-saturated reconstruction.
        plt.figure(figsize=(5.5, 5.5))
        hue = (1 - np.sign(scaled)) / 4 * 2 / 3
        saturation = np.abs(scaled)
        saturation = saturation / np.max(saturation)
        value = np.clip(reconf, 0, 1)
        hsv = np.dstack((hue, saturation, value))
        rgb = hsv_to_rgb(hsv)
        plt.title('Bootstrap-Saturated Reconstruction')
        plt.imshow(np.clip(rgb, 0, 1))
        plt.savefig(rest + '_saturated' + suffix, bbox_inches='tight')
        # Plot a bootstrap-interpolated reconstruction.
        plt.figure(figsize=(5.5, 5.5))
        hue = 7. / 12 + np.sign(scaled) * 3 / 12
        saturation = np.abs(scaled)
        saturation = saturation / np.max(saturation)
        value = np.clip(reconf, 0, 1)
        hsv = np.dstack((hue, saturation, value))
        rgb = hsv_to_rgb(hsv)
        plt.title('Bootstrap-Interpolated Reconstruction')
        plt.imshow(np.clip(rgb, 0, 1))
        plt.savefig(rest + '_interpolated' + suffix, bbox_inches='tight')
        # Plot the blurred bootstrap.
        plt.figure(figsize=(5.5, 5.5))
        plt.title('Blurred Bootstrap')
        plt.imshow(np.clip(blurred, -1, 1), **kwargs11)
        plt.savefig(rest + '_blurred' + suffix, bbox_inches='tight')

    return lossf, loss, rsse_estimated, rsse_blurred