in bootstrap2.py [0:0]
def _bootstrap2_add_low_frequencies(mask, m, n, low):
    """
    marks (in place) the low x low block at each corner of mask as True

    The four corners of the m x n array mask hold the frequencies between
    -low and low in both directions when the zero frequency sits at index
    [0, 0], which is the layout that np.fft.fft2 produces.
    """
    for km in range(low):
        for kn in range(low):
            mask[km, kn] = True
            mask[m - 1 - km, kn] = True
            mask[km, n - 1 - kn] = True
            mask[m - 1 - km, n - 1 - kn] = True


def _bootstrap2_save_plot(image, title, path, **imshow_kwargs):
    """
    draws image with the given title, saves it to path, and closes the figure

    Closing the figure matters: pyplot keeps every figure alive until it is
    closed explicitly, so repeated calls would otherwise accumulate memory
    and trigger matplotlib's warning about too many open figures.
    """
    fig = plt.figure(figsize=(5.5, 5.5))
    try:
        plt.title(title)
        plt.imshow(image, **imshow_kwargs)
        plt.savefig(path, bbox_inches='tight')
    finally:
        plt.close(fig)


def _bootstrap2_overlay(errors, reconf):
    """
    returns an RGB image of reconf with the largest errors colored on top

    Pixels whose absolute error exceeds (roughly) the 99th percentile are
    drawn fully saturated at full brightness, with a hue that depends on the
    signed error relative to the maximal absolute error; every other pixel
    shows reconf in grayscale.
    """
    magnitudes = np.sort(np.abs(errors).flatten())
    maxthresh = magnitudes[-1]
    # Clamp the index so that tiny images (fewer than about 50 pixels) do not
    # index one past the end of the sorted magnitudes.
    cutoff = min(round(0.99 * magnitudes.size), magnitudes.size - 1)
    threshold = magnitudes[cutoff]
    # Avoid 0/0 = NaN hues when the estimated errors vanish identically.
    if maxthresh == 0:
        relative = np.zeros_like(errors)
    else:
        relative = errors / maxthresh
    hue = 2. / 3 + relative / 4 * 2 / 3
    saturation = np.abs(errors) > threshold
    value = reconf * (1 - saturation) + saturation
    return hsv_to_rgb(np.dstack((hue, saturation, value)))


def _bootstrap2_tint(hue, errors, reconf):
    """
    returns an RGB image of reconf tinted with hue in proportion to |errors|

    The saturation is the absolute error normalized by its maximum, and
    the value (brightness) is reconf clipped to [0, 1].
    """
    saturation = np.abs(errors)
    peak = np.max(saturation)
    # Leave the saturation at zero (rather than NaN) when all errors vanish.
    if peak != 0:
        saturation = saturation / peak
    value = np.clip(reconf, 0, 1)
    return hsv_to_rgb(np.dstack((hue, saturation, value)))


def bootstrap2(filein, fileout, subsampling_factor, angles, low, recon,
               recon_args, n_resamps=None, sdn=None, viz=None):
    """
    plots thrice the average of bootstrapped errors in reconstruction

    Plots and saves thrice the average of the differences between
    a reconstruction of the original image in filein and n_resamps bootstrap
    reconstructions via recon applied to the k-space subsamplings specified
    by angles fed into radialines.randradialineset and by other masks
    generated similarly (each resampled mask uses
    round(2 * (m + n) * subsampling_factor) angles drawn uniformly at random
    from [0, 2 pi), then adds all frequencies between -low to low in both
    directions), corrupting the k-space values of the original image with
    independent and identically distributed centered complex Gaussian noise
    whose standard deviation is sdn*sqrt(2) (sdn=0 if not provided
    explicitly). The bootstrap resamplings subsample the Fourier transform
    of the reconstruction obtained from the original mask.

    Setting viz to be True yields colorized visualizations, too, including
    the error estimates overlaid over the reconstruction, the error estimates
    blurred overlaid over the reconstruction, the error estimates blurred,
    the error estimates subtracted from the reconstruction, the error estimates
    saturating the reconstruction in hue-saturation-value (HSV) color space,
    and the error estimates interpolating the reconstruction in HSV space.

    The calling sequence of recon must be  (m, n, f, mask, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask (moved onto the CUDA device), mask is the return from calls to
    radialines.randradialineset, supplemented by all frequencies between -low
    to low in both directions, and **recon_args is the unpacking of
    recon_args. The function recon must return a torch.Tensor
    (the reconstruction) and a float (the corresponding loss).

    The plots get saved to files named after fileout, with '_original',
    '_recon', '_error', and '_bootstrap' inserted before the extension
    (and, when viz is True, also '_corrected', '_overlaid',
    '_blurred_overlaid', '_saturated', '_interpolated', and '_blurred').

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        fraction determining how many radial "lines" each resampled mask
        uses, namely round(2 * (m + n) * subsampling_factor)
    angles : list of float
        angles of the radial "lines" in the mask that radialines will construct
    low : int
        bandwidth of low frequencies included in mask (between -low to low
        in both the horizontal and vertical directions)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    n_resamps : int, optional
        number of bootstrap resampled reconstructions (defaults to 100)
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)
    viz : bool, optional
        indicator of whether to generate colorized visualizations
        (defaults to False)

    Returns
    -------
    float
        loss for the reconstruction using the original angles
    list of float
        losses for the reconstructions using other, randomly generated masks
    float
        square root of the sum of the square of the estimated errors
    float
        square root of the sum of the square of the estimated errors blurred

    Notes
    -----
    This function modifies matplotlib.rcParams globally (hiding ticks, tick
    labels, and spines), and the changes persist after the function returns.
    Every figure that the function creates gets closed after being saved.
    """
    # Imported locally so that this block does not depend on the file-level
    # imports including os.
    import os

    # Set default parameters.
    if n_resamps is None:
        n_resamps = 100
    if sdn is None:
        sdn = 0
    if viz is None:
        viz = False
    # Read the image from disk.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Fourier transform the image (with the unitary normalization).
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise.
    ff_noisy = ff_orig.copy()
    ff_noisy += sdn * (np.random.randn(m, n) + 1j * np.random.randn(m, n))
    # Select which frequencies to retain.
    mask = radialines.randradialineset(m, n, angles)
    # Include all low frequencies.
    _bootstrap2_add_low_frequencies(mask, m, n, low)
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    logging.info(
        'computing bootstrap2 resamplings -- all {}'.format(n_resamps))
    # Perform the reconstruction using the mask.
    reconf, lossf = recon(m, n, f, mask, **recon_args)
    reconf = reconf.cpu().numpy()
    # Fourier transform the reconstruction.
    freconf = np.fft.fft2(reconf) / np.sqrt(m * n)
    # Perform the reconstruction resampling new masks and samples in k-space.
    recons = np.empty((n_resamps, m, n))
    loss = []
    for k in range(n_resamps):
        # Select which frequencies to retain.
        angles1 = np.random.uniform(
            low=0, high=(2 * np.pi),
            size=round(2 * (m + n) * subsampling_factor))
        mask1 = radialines.randradialineset(m, n, angles1)
        # Include all low frequencies.
        _bootstrap2_add_low_frequencies(mask1, m, n, low)
        # Subsample the Fourier transform of the reconstruction.
        f1 = ctorch.from_numpy(freconf[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        recon1, loss1 = recon(m, n, f1, mask1, **recon_args)
        recon1 = recon1.cpu().numpy()
        # Record the results.
        recons[k, :, :] = recon1
        loss.append(loss1)
    # Calculate the sum of the bootstrap differences.
    sumboo = np.sum(recons - reconf, axis=0)
    scaled = sumboo * 3 / n_resamps
    # Blur the error estimates.
    sigma = 1
    blurred = skimage.filters.gaussian(scaled, sigma=sigma)
    rsse_estimated = np.linalg.norm(scaled, ord='fro')
    rsse_blurred = np.linalg.norm(blurred, ord='fro')

    # Plot errors.
    # Remove the ticks and spines on the axes.
    matplotlib.rcParams.update({
        'xtick.top': False,
        'xtick.bottom': False,
        'ytick.left': False,
        'ytick.right': False,
        'xtick.labeltop': False,
        'xtick.labelbottom': False,
        'ytick.labelleft': False,
        'ytick.labelright': False,
        'axes.spines.top': False,
        'axes.spines.bottom': False,
        'axes.spines.left': False,
        'axes.spines.right': False,
    })
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the suffix (filetype) from the rest of the filename;
    # os.path.splitext looks only at the final path component, so dots
    # in directory names cannot be mistaken for the start of the suffix.
    rest, suffix = os.path.splitext(fileout)
    # Plot the original.
    _bootstrap2_save_plot(
        np.clip(f_orig, 0, 1), 'Original',
        rest + '_original' + suffix, **kwargs01)
    # Plot the reconstruction from the original mask provided.
    _bootstrap2_save_plot(
        np.clip(reconf, 0, 1), 'Reconstruction',
        rest + '_recon' + suffix, **kwargs01)
    # Plot the difference from the original.
    _bootstrap2_save_plot(
        np.clip(reconf - f_orig, -1, 1), 'Error of Reconstruction',
        rest + '_error' + suffix, **kwargs11)
    # Plot thrice the average of the bootstrap differences.
    _bootstrap2_save_plot(
        np.clip(scaled, -1, 1), 'Bootstrap',
        rest + '_bootstrap' + suffix, **kwargs11)
    if viz:
        # Plot the reconstruction minus the bootstrap difference.
        _bootstrap2_save_plot(
            np.clip(reconf - scaled, 0, 1), 'Reconstruction \u2013 Bootstrap',
            rest + '_corrected' + suffix, **kwargs01)
        # Overlay the error estimates on the reconstruction.
        rgb = _bootstrap2_overlay(scaled, reconf)
        _bootstrap2_save_plot(
            np.clip(rgb, 0, 1), 'Errors Over a Threshold Overlaid',
            rest + '_overlaid' + suffix)
        # Overlay the blurred error estimates on the reconstruction.
        rgb = _bootstrap2_overlay(blurred, reconf)
        _bootstrap2_save_plot(
            np.clip(rgb, 0, 1), 'Blurred Errors Over a Threshold Overlaid',
            rest + '_blurred_overlaid' + suffix)
        # Plot a bootstrap-saturated reconstruction.
        hue = (1 - np.sign(scaled)) / 4 * 2 / 3
        rgb = _bootstrap2_tint(hue, scaled, reconf)
        _bootstrap2_save_plot(
            np.clip(rgb, 0, 1), 'Bootstrap-Saturated Reconstruction',
            rest + '_saturated' + suffix)
        # Plot a bootstrap-interpolated reconstruction.
        hue = 7. / 12 + np.sign(scaled) * 3 / 12
        rgb = _bootstrap2_tint(hue, scaled, reconf)
        _bootstrap2_save_plot(
            np.clip(rgb, 0, 1), 'Bootstrap-Interpolated Reconstruction',
            rest + '_interpolated' + suffix)
        # Plot the blurred bootstrap.
        _bootstrap2_save_plot(
            np.clip(blurred, -1, 1), 'Blurred Bootstrap',
            rest + '_blurred' + suffix, **kwargs11)

    return lossf, loss, rsse_estimated, rsse_blurred