def runtestadmm()

in admm2.py [0:0]


def runtestadmm(method, cpu, filename, mu=1e12, beta=1, subsampling_factor=0.7,
                n_iter=100, seed=None):
    """Run tests as specified.

    Use the specified method (on CPUs if cpu is True), reading the image from
    filename, with the lasso regularization parameter mu and the ADMM coupling
    parameter beta, subsampling by subsampling_factor, for n_iter iterations
    of ADMM, seeding the random number generator with the provided seed.

    Side effects: seeds numpy's global random number generator, prints the
    run configuration, elapsed time, and final loss, and writes a 2x2 figure
    (original, reconstruction, zero-padded reconstruction, sampling mask) to
    'recon2.png' in the working directory.

    Parameters
    ----------
    method : str
        which algorithm to use ('cs_baseline' or 'cs_fft')
    cpu : boolean
        set to true to perform all computations on the CPU(s); CUDA is never
        touched in that case
    filename : str
        name of the file containing the image to be processed; prepend a path
        if the file resides outside the working directory; the image must
        decode to a two-dimensional (single-channel) array
    mu : float
        regularization parameter
    beta : float
        coupling parameter for the ADMM iterations
    subsampling_factor : float
        probability of retaining an entry in k-space
    n_iter : int
        number of ADMM iterations to conduct
    seed : int
        seed value for numpy's random number generator

    Returns
    -------
    float
        objective value at the end of the ADMM iterations, as returned by the
        selected solver (cs_baseline or cs_fft)

    Raises
    ------
    ValueError
        if the image read from filename is not two-dimensional
    NotImplementedError
        if method is not 'cs_baseline' or 'cs_fft', or if 'cs_baseline' is
        requested with cpu set to False
    """

    def sync():
        """Wait for all queued CUDA work to finish (GPU runs only)."""
        # CPU runs must not touch CUDA at all: torch.cuda.synchronize()
        # raises on machines without a usable CUDA device.
        if not cpu:
            torch.cuda.synchronize()

    def tic():
        """
        Timing starting.

        Records the current time, after draining pending GPU work.

        Returns
        -------
        float
            present time in fractional seconds
        """
        sync()
        return time.perf_counter()

    def toc(t):
        """
        Timing stopping.

        Reports the difference of the current time from the reference provided,
        after draining pending GPU work.

        Parameters
        ----------
        t : float
            reference time in fractional seconds

        Returns
        -------
        float
            difference of the present time from the reference t
        """
        sync()
        return time.perf_counter() - t

    # Fix the random seed if appropriate.
    np.random.seed(seed=seed)
    # Read the image from disk; the context manager closes the file handle.
    # NOTE(review): dividing by 255 assumes 8-bit pixel data — confirm for
    # other image modes.
    with Image.open(filename) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    if f_orig.ndim != 2:
        # A multi-channel image would make fft2 transform the wrong axes and
        # misalign with the (m, n) mask below.
        raise ValueError(
            'expected a single-channel (2-D) image, got array of shape '
            '{}'.format(f_orig.shape))
    m, n = f_orig.shape
    # Select which k-space frequencies to retain.
    mask = np.random.uniform(size=(m, n)) < subsampling_factor
    # Make the optimization well-posed by including the zero frequency.
    mask[0, 0] = True
    # Subsample the (unitarily normalized) Fourier transform of the image.
    f = np.fft.fft2(f_orig)[mask] / np.sqrt(m * n)
    # Announce the run before starting the timer so the print is not timed.
    print('Running {}(cpu={}, mu={}, beta={}, n_iter={})'.format(
        method, cpu, mu, beta, n_iter))
    # Start timing.
    t = tic()
    # Reconstruct the image from the undersampled Fourier data.
    if method == 'cs_baseline':
        if cpu:
            x, loss = cs_baseline(m, n, f, mask, mu=mu, beta=beta,
                                  n_iter=n_iter)
        else:
            raise NotImplementedError('A baseline on GPUs is not implemented' +
                                      '; use \'--cpu\'')
    elif method == 'cs_fft':
        if cpu:
            x, loss = cs_fft(m, n, f, mask, mu=mu, beta=beta, n_iter=n_iter)
        else:
            # Move the Fourier data to the GPUs.
            f_th = ctorch.from_numpy(f).cuda()
            # The first call to `ctorch.fft2` is slow;
            # run a dummy fft2 and restart the timer to get accurate timings.
            ctorch.fft2(ctorch.from_numpy(np.fft.fft2(f_orig)).cuda())
            t = tic()
            x, loss = cs_fft(m, n, f_th, mask, mu=mu, beta=beta, n_iter=n_iter)
            # Bring the reconstruction back to the host for plotting.
            x = x.cpu().numpy()
    else:
        raise NotImplementedError('method must be either \'cs_baseline\' ' +
                                  'or \'cs_fft\'')
    # Stop timing.
    tt = toc(t)
    # Print the time taken and final loss.
    print('time={}s'.format(tt))
    print('loss={}'.format(loss))
    # Plot the original image, its reconstruction, and the sampling pattern.
    fig = plt.figure(figsize=(12, 12))
    plt.subplot(221)
    plt.title('Original')
    plt.imshow(f_orig, cmap='gray')
    plt.subplot(222)
    plt.title('Compressed sensing reconstruction')
    plt.imshow(x.reshape(m, n), cmap='gray')
    plt.subplot(223)
    plt.title('Naive (zero-padded ifft2) reconstruction')
    plt.imshow(np.abs(zero_pad(m, n, f, mask)), cmap='gray')
    plt.subplot(224)
    plt.title('Sampling mask')
    plt.imshow(mask, cmap='gray')
    plt.savefig('recon2.png', bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return loss