in bootstrap.py [0:0]
def testbootstrap(filein, fileout, subsampling_factor, recon, recon_args,
                  n_resamps=None, sdn=None, viz=None):
    """
    tests bootstrap

    Runs bootstrap (and prints the losses) with filein, fileout,
    subsampling_factor, recon, recon_args, n_resamps, sdn, and viz, creating
    a random mask over the m rows of the image in filein. The random part of
    the mask retains the row indices floor(m * u) for round(m *
    subsampling_factor) draws u taken uniformly from [0, 1) with replacement.
    Duplicate draws collapse, so the fraction of rows retained at random is
    somewhat smaller than subsampling_factor. The mask is then supplemented by
    the rows with indices -(low - 1), ..., -1, 0, 1, ..., low - 1, where
    low = round(sqrt(2m)) and negative indices wrap around to the end of the
    mask. The random draws come from NumPy's global generator
    (np.random.uniform), so call np.random.seed beforehand for a
    reproducible mask.

    The calling sequence of recon must be (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        nominal probability of retaining a row in the subsampling masks; here
        it sets the number of random row draws, round(m * subsampling_factor),
        made with replacement when building the mask
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    n_resamps : int, optional
        number of bootstrap resampled reconstructions (defaults to 100
        in bootstrap)
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0 in bootstrap)
    viz : bool, optional
        indicator of whether to generate colorized visualizations
        (defaults to False in bootstrap)
    """
    # Obtain the size of the input image. PIL reports (width, height), so the
    # height is the number of rows m; the width is not needed here.
    with Image.open(filein) as img:
        _, m = img.size
    # Select which frequencies to retain: round(m * subsampling_factor) row
    # indices drawn uniformly at random with replacement (duplicates simply
    # set the same entry twice).
    draws = np.random.uniform(size=round(m * subsampling_factor))
    mask = np.zeros(m, dtype=bool)
    mask[np.floor(m * draws).astype(int)] = True
    # Make the optimization well-posed by including the zero frequency.
    mask[0] = True
    # Include all low frequencies: indices 0, ..., low - 1 and their
    # negatives, which wrap around to the end of the mask (-0 is index 0).
    low = round(math.sqrt(2. * m))
    offsets = np.arange(low)
    mask[offsets] = True
    mask[-offsets] = True
    # Generate bootstrap plots.
    loss, losses, rsse_estimated, rsse_blurred = bootstrap(
        filein, fileout, subsampling_factor, mask, low, recon, recon_args,
        n_resamps, sdn, viz)
    # Display the losses.
    print('loss = {}'.format(loss))
    print('losses = {}'.format(losses))
    # Display the estimated root-sum-square errors.
    print('Frobenius norm of the bootstrap = {}'.format(rsse_estimated))
    print('Frobenius norm of the blurred bootstrap = {}'.format(rsse_blurred))