def testbootstrap()

in bootstrap.py [0:0]


def testbootstrap(filein, fileout, subsampling_factor, recon, recon_args,
                  n_resamps=None, sdn=None, viz=None):
    """
    tests bootstrap

    Runs bootstrap (and prints the losses) with filein, fileout,
    subsampling_factor, recon, recon_args, n_resamps, sdn, and viz, using a
    random mask over the m rows of the image in filein.

    The mask is built from round(m * subsampling_factor) row indices drawn
    uniformly at random *with replacement*. Each row is therefore retained
    with probability 1 - (1 - 1/m)**round(m * subsampling_factor). This is
    approximately subsampling_factor when subsampling_factor is small, and
    somewhat less otherwise.

    The mask is then supplemented by the zero frequency and by every
    frequency index k with |k| < low, where low = round(sqrt(2m)). In terms
    of array positions, these are rows 0 through low-1 and, wrapping
    around, rows m-low+1 through m-1.

    The calling sequence of recon must be  (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        number of random row draws per row of the image (the draws are made
        with replacement, so the fraction of rows actually retained by the
        random part of the mask is slightly below this value)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    n_resamps : int, optional
        number of bootstrap resampled reconstructions (defaults to 100
        in bootstrap)
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0 in bootstrap)
    viz : bool, optional
        indicator of whether to generate colorized visualizations
        (defaults to False in bootstrap)

    Returns
    -------
    None
        The results are printed to standard output, not returned.
    """
    # Obtain the size of the input image; PIL reports (width, height),
    # so n is the number of columns and m is the number of rows.
    with Image.open(filein) as img:
        n, m = img.size
    # Select which frequencies to retain: draw round(m * subsampling_factor)
    # row indices uniformly at random (with replacement). Since the uniform
    # draws lie in [0, 1), the floored indices lie in 0, ..., m-1.
    mask = np.zeros(m, dtype=bool)
    draws = np.floor(m * np.random.uniform(size=round(m * subsampling_factor)))
    mask[draws.astype(int)] = True
    # Make the optimization well-posed by including the zero frequency.
    mask[0] = True
    # Include all low frequencies |k| < low. Negative offsets wrap around to
    # the end of the array, covering the negative frequencies.
    low = round(math.sqrt(2. * m))
    offsets = np.arange(low)
    mask[offsets] = True
    mask[-offsets] = True
    # Generate bootstrap plots.
    loss, losses, rsse_estimated, rsse_blurred = bootstrap(
        filein, fileout, subsampling_factor, mask, low, recon, recon_args,
        n_resamps, sdn, viz)
    # Display the losses.
    print('loss = {}'.format(loss))
    print('losses = {}'.format(losses))
    # Display the estimated root-sum-square errors.
    print('Frobenius norm of the bootstrap = {}'.format(rsse_estimated))
    print('Frobenius norm of the blurred bootstrap = {}'.format(rsse_blurred))