def cs_fft()

in admm.py [0:0]


def cs_fft(m, n, f, mask, mu, beta, n_iter):
    """
    Recovers an image from a subset of its frequencies using FFTs.

    Reconstructs an m x n image from the subset f of its frequencies specified
    by mask, using ADMM with regularization parameter mu, coupling parameter
    beta, and number of iterations n_iter. Unlike function cs_baseline,
    this cs_fft uses FFTs. The computations take place on the CPU(s) in numpy
    when f is a numpy.ndarray and take place on the GPU(s) in ctorch when f is
    a ctorch.ComplexTensor.

    The objective reported (and targeted by the ADMM iterations) is the
    isotropic total variation of the image, computed with periodic
    (wrap-around) finite differences, plus (mu / 2) times the squared
    Euclidean norm of the difference between the rows selected by mask of
    the unitarily normalized 2-D DFT of the image and f.

    _N.B._: mask[0] must be True in order to make the optimization well-posed.

    Parameters
    ----------
    m : int
        number of rows in the image being reconstructed
    n : int
        number of columns in the image being reconstructed
    f : numpy.ndarray or ctorch.ComplexTensor
        potentially nonzero rows (prior to the inverse Fourier transform);
        must have n columns (presumably one row per True entry of mask)
    mask : numpy.ndarray or torch.Tensor
        boolean indicators of the positions of the rows in the full m x n array
        (a numpy.ndarray when f is a numpy.ndarray, and a torch tensor when f
        is a ctorch.ComplexTensor) -- note that the zero frequency entry must
        be True in order to make the optimization well-posed
    mu : float
        regularization parameter
    beta : float
        coupling parameter for the ADMM iterations
    n_iter : int
        number of ADMM iterations to conduct

    Returns
    -------
    numpy.ndarray or ctorch.ComplexTensor
        reconstructed m x n image (the real part of the complex iterate)
    float
        objective value at the end of the ADMM iterations (see function adm)

    Raises
    ------
    AssertionError
        if f does not have n columns or if mask[0] is not True
    TypeError
        if f is neither a numpy.ndarray nor a ctorch.ComplexTensor
    """

    def image_gradient(x):
        """
        First-order finite-differencing both horizontally and vertically.

        Computes a first-order finite-difference approximation to the gradient
        with periodic boundary conditions: d_x[:, j] = x[:, j] - x[:, j - 1]
        and d_y[i, :] = x[i, :] - x[i - 1, :], indices taken modulo the size.

        Parameters
        ----------
        x : numpy.ndarray or ctorch.ComplexTensor
            image (that is, two-dimensional array)

        Returns
        -------
        numpy.ndarray or ctorch.ComplexTensor
            horizontal finite differences of x stacked on top of the vertical
            finite differences (separating horizontal from vertical via the
            initial dimension), each flattened, for shape (2, x.size)
        """
        if isinstance(x, np.ndarray):
            # Wrap the last column of x around to the beginning.
            x_h = np.hstack((x[:, -1:], x))
            # Wrap the last row of x around to the beginning.
            x_v = np.vstack((x[-1:], x))
            # Difference each column with its (periodic) left neighbor.
            d_x = (x_h[:, 1:] - x_h[:, :-1])
            # Difference each row with its (periodic) upper neighbor.
            d_y = (x_v[1:] - x_v[:-1])
            return np.vstack((d_x.ravel(), d_y.ravel()))
        elif isinstance(x, ctorch.ComplexTensor):
            # Wrap the last column of x around to the beginning.
            x_h = ctorch.cat((x[:, -1:], x), dim=1)
            # Wrap the last row of x around to the beginning.
            x_v = ctorch.cat((x[-1:], x), dim=0)
            # Difference each column with its (periodic) left neighbor.
            d_x = (x_h[:, 1:] - x_h[:, :-1])
            # Difference each row with its (periodic) upper neighbor.
            d_y = (x_v[1:] - x_v[:-1])
            # Concatenating along dim 0 and viewing as 2 rows flattens d_x
            # and d_y in row-major order, matching the numpy branch.
            return ctorch.cat((d_x, d_y)).view(2, -1)
        else:
            raise TypeError('Input must be a numpy.ndarray ' +
                            'or a ctorch.ComplexTensor.')

    def image_gradient_T(x):
        """
        Transpose of the operator that function image_gradient implements.

        Computes the transpose of the matrix given by function image_gradient.

        Parameters
        ----------
        x : numpy.ndarray or ctorch.ComplexTensor
            stack of two identically shaped arrays (horizontal then vertical)

        Returns
        -------
        numpy.ndarray or ctorch.ComplexTensor
            result of applying to x the transpose of function image_gradient
        """
        if isinstance(x, np.ndarray):
            x_h = x[0]
            x_v = x[1]
            # Wrap the first column of x_h around to the end.
            x_h_ext = np.hstack((x_h, x_h[:, :1]))
            # Wrap the first row of x_v around to the end.
            x_v_ext = np.vstack((x_v, x_v[:1]))
            # Difference each column with its (periodic) right neighbor.
            d_x = x_h_ext[:, :-1] - x_h_ext[:, 1:]
            # Difference each row with its (periodic) lower neighbor.
            d_y = x_v_ext[:-1] - x_v_ext[1:]
            return d_x + d_y
        elif isinstance(x, ctorch.ComplexTensor):
            x_h = x[0]
            x_v = x[1]
            # Wrap the first column of x_h around to the end.
            x_h_ext = ctorch.cat((x_h, x_h[:, :1]), dim=1)
            # Wrap the first row of x_v around to the end.
            x_v_ext = ctorch.cat((x_v, x_v[:1]), dim=0)
            # Difference each column with its (periodic) right neighbor.
            d_x = x_h_ext[:, :-1] - x_h_ext[:, 1:]
            # Difference each row with its (periodic) lower neighbor.
            d_y = x_v_ext[:-1] - x_v_ext[1:]
            return d_x + d_y
        else:
            raise TypeError('Input must be a numpy.ndarray ' +
                            'or a ctorch.ComplexTensor.')

    def difference_power_spectrum(length):
        """
        Squared magnitude of the DFT of a periodic first-difference kernel.

        The kernel is [1, -1, 0, ..., 0] of the given length when length >= 2.
        When length == 1 the periodic difference x[0] - x[-1] vanishes
        identically, so the kernel is [0]; this keeps the spectrum's length
        equal to the image axis it multiplies in the Fourier domain.
        """
        kernel = np.zeros(length)
        kernel[0] = 1
        # 1 % length is 1 for length >= 2 and 0 for length == 1.
        kernel[1 % length] -= 1
        return np.abs(np.fft.fft(kernel))**2

    def fourier_multipliers(mask_np):
        """
        Multipliers that solve formula (2.8) from Tao-Yang in Fourier space.

        Broadcasts the spectrum of the horizontal difference kernel (length n)
        over the rows, and broadcasts both the spectrum of the vertical
        difference kernel (length m) and the row subsampling mask over the
        columns, yielding an m x n array of reciprocals.
        """
        tx = difference_power_spectrum(m)
        ty = difference_power_spectrum(n)
        return 1. / (ty + tx[:, None] + (mu / beta) * mask_np[:, None])

    if isinstance(f, np.ndarray):
        assert f.shape[1] == n, 'f must have n columns'
        assert mask[0], 'mask[0] must be True for well-posedness'
        # Rescale f and pad with zeros between the mask samples.
        Ktf = (mu / beta) * zero_pad(m, n, f, mask)
        multipliers = fourier_multipliers(mask)
        # Initialize the primal (x) and dual (la) solutions to zeros.
        x = np.zeros((m, n))
        la = np.zeros((2, m * n))
        # Cache the gradient of the current iterate; it feeds the dual update,
        # the next shrinkage step, and the final loss.
        grad_x = image_gradient(x)
        # Calculate iterations of alternating minimization.
        for i in range(n_iter):
            if i > 0:
                # Apply shrinkage via formula (2.7) from Tao-Yang, dividing
                # both arguments of the "max" operator in formula (2.7) by the
                # denominator of the rightmost factor in formula (2.7).
                a = grad_x + la / beta
                b = scipy.linalg.norm(a, axis=0, keepdims=True)
                # Wherever b == 0 the column of a is zero too, and the factor
                # max(1 - inf, 0) == 0 yields the correct y == 0; silence the
                # otherwise spurious divide-by-zero warning.
                with np.errstate(divide='ignore'):
                    y = a * np.maximum(1 - 1 / (beta * b), 0)
            else:
                # x and la start at zero, so the shrinkage would give zero.
                y = np.zeros((2, m * n))
            # Solve formula (2.8) from Tao-Yang in the Fourier domain.
            c = image_gradient_T((y - la / beta).reshape((2, m, n))) + Ktf
            x = np.fft.ifft2(np.fft.fft2(c) * multipliers)
            grad_x = image_gradient(x)
            # Update the Lagrange multipliers via formula (2.9) from Tao-Yang.
            la = la - beta * (y - grad_x)
        # Calculate the loss in formula (1.4) from Tao-Yang...
        loss = np.linalg.norm(grad_x, axis=0).sum()
        # ... adding in the term for the fidelity of the reconstruction.
        loss += np.linalg.norm(
            np.fft.fft2(x)[mask] / np.sqrt(m * n) - f)**2 * (mu / 2)
        # Discard the imaginary part of the primal solution,
        # returning only the real part and the loss.
        return x.real, loss
    elif isinstance(f, ctorch.ComplexTensor):
        assert f.shape[1] == n, 'f must have n columns'
        assert mask[0], 'mask[0] must be True for well-posedness'
        # Convert the mask from booleans to long integers.
        mask_nnz = torch.nonzero(mask, as_tuple=False).squeeze(1)
        # Rescale f and pad with zeros between the mask samples.
        Ktf = zero_pad(m, n, f, mask_nnz) * (mu / beta)
        multipliers = fourier_multipliers(mask.cpu().numpy())
        # NOTE(review): .cuda() always targets the default CUDA device, so a
        # CPU-resident f would mismatch here -- confirm this is intended.
        multipliers = ctorch.from_numpy(multipliers).cuda()
        # Initialize the primal (x) and dual (la) solutions to zeros,
        # creating new ctorch tensors of the same type as f.
        x = f.new(m, n).zero_()
        la = f.new(2, m * n).zero_()
        # Cache the gradient of the current iterate; it feeds the dual update,
        # the next shrinkage step, and the final loss.
        grad_x = image_gradient(x)
        # Calculate iterations of alternating minimization.
        for i in range(n_iter):
            if i > 0:
                # Apply shrinkage via formula (2.7) from Tao-Yang, dividing
                # both arguments of the "max" operator in formula (2.7) by the
                # denominator of the rightmost factor in formula (2.7).
                a = grad_x + la / beta
                b = ctorch.norm(a, p=2, dim=0, keepdim=True)
                y = a * torch.clamp(1 - 1 / (beta * b), min=0)
            else:
                # x and la start at zero, so the shrinkage would give zero.
                y = f.new(2, m * n).zero_()
            # Solve formula (2.8) from Tao-Yang in the Fourier domain.
            c = image_gradient_T((y - la / beta).view(2, m, n)) + Ktf
            x = ctorch.ifft2(ctorch.fft2(c) * multipliers)
            grad_x = image_gradient(x)
            # Update the Lagrange multipliers via formula (2.9) from Tao-Yang.
            la = la - (y - grad_x) * beta
        # Calculate the loss in formula (1.4) from Tao-Yang...
        loss = ctorch.norm(grad_x, p=2, dim=0).sum()
        # ... adding in the term for the fidelity of the reconstruction.
        loss += ctorch.norm(
            ctorch.fft2(x)[mask_nnz] / math.sqrt(m * n) - f)**2 * (mu / 2)
        # Discard the imaginary part of the primal solution,
        # returning only the real part and the loss.
        return x.real, loss.item()
    else:
        raise TypeError('Input must be a numpy.ndarray ' +
                        'or a ctorch.ComplexTensor.')