def testbootstrap2()

in bootstrap2.py [0:0]


def testbootstrap2(filein, fileout, subsampling_factor, recon, recon_args,
                   n_resamps=None, sdn=None, viz=None):
    """
    exercises bootstrap2 on a random set of radial "lines" and prints results

    Reads the dimensions of the image stored in filein, draws random angles
    for the radial "lines" to retain, and runs bootstrap2 on them together
    with all the other arguments. It then prints the four values that
    bootstrap2 returns:
    the loss, the losses, and the Frobenius norms (estimated root-sum-square
    errors) of the bootstrap and of the blurred bootstrap. Nothing is returned.

    The angles are drawn independently and uniformly from [0, 2 pi). For an
    m x n image (m rows, n columns), the number of angles is
    round(2 * (m + n) * subsampling_factor). Because the angles are drawn
    without seeding, successive calls generally use different angles.

    The function recon must accept the calling sequence
    (m, n, f, mask, **recon_args). Here m x n is the size of the image in
    filein, f is the image in k-space subsampled to the mask, mask is the
    return from calls to radialines (with the drawn angles), and
    **recon_args unpacks recon_args. recon must return a torch.Tensor (the
    reconstruction) together with a float (the corresponding loss).

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        scales the number of radial "lines" retained in the subsampling mask
        (and is also passed through to bootstrap2)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    n_resamps : int, optional
        number of bootstrap resampled reconstructions (defaults to 100
        in bootstrap2)
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0 in bootstrap2)
    viz : bool, optional
        indicator of whether to generate colorized visualizations
        (defaults to False in bootstrap2)
    """
    # PIL reports an image's size as (width, height), i.e., (columns, rows).
    with Image.open(filein) as image:
        n_cols, n_rows = image.size

    # Number of radial "lines" to keep, proportional to the image perimeter.
    n_angles = round(2 * (n_rows + n_cols) * subsampling_factor)
    angles = np.random.uniform(low=0, high=(2 * np.pi), size=n_angles)

    # Passed through to bootstrap2 unchanged.
    # NOTE(review): its meaning is defined in bootstrap2; confirm there.
    low = 0
    results = bootstrap2(
        filein, fileout, subsampling_factor, angles, low, recon, recon_args,
        n_resamps, sdn, viz)
    loss, losses, rsse_estimated, rsse_blurred = results

    # Report the losses followed by the estimated root-sum-square errors.
    print('loss = {}'.format(loss))
    print('losses = {}'.format(losses))
    print('Frobenius norm of the bootstrap = {}'.format(rsse_estimated))
    print('Frobenius norm of the blurred bootstrap = {}'.format(rsse_blurred))