def testjackknife()

in jackknife.py [0:0]


def testjackknife(filein, fileout, subsampling_factor, recon, recon_args,
                  sdn=None):
    """
    tests jackknife

    Runs jackknife (and prints the losses) with filein, fileout, recon,
    recon_args, and sdn, creating a random mask that retains each row
    independently with probability subsampling_factor, supplemented by all
    rows whose (wrap-around) index lies strictly between -low and low, where
    low = round(sqrt(2m)) and filein contains an image with m rows.

    The calling sequence of recon must be  (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        probability of retaining each row in the random part of the
        subsampling mask (values >= 1 retain every row)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0 in jackknife)
    """
    # Obtain the size of the input image; Image.size is (width, height), so
    # m is the number of rows and n the number of columns.
    with Image.open(filein) as img:
        n, m = img.size
    # Select which frequencies to retain: an independent Bernoulli draw per
    # row. (Drawing round(m * subsampling_factor) indices with replacement
    # would retain fewer rows than requested, because duplicates collapse.)
    mask = np.random.uniform(size=m) < subsampling_factor
    # Make the optimization well-posed by including the zero frequency.
    mask[0] = True
    # Include all low frequencies: indices 0, ..., low - 1 and, via
    # wrap-around, m - low + 1, ..., m - 1 (i.e. -(low - 1), ..., -1).
    low = round(math.sqrt(2. * m))
    mask[:low] = True
    mask[m - low + 1:] = True
    # Generate jackknife plots.
    loss, losses = jackknife(filein, fileout, mask, low, recon, recon_args,
                             sdn)
    # Display the losses.
    print('loss = {}'.format(loss))
    print('losses = {}'.format(losses))