in banding_removal/fastmri/visualization_mixin.py [0:0]
def visualize_dev(self, epoch):
self.model.eval()
grid_size = self.args.display_count
if grid_size == 0:
return
logging.debug("Saving visualizations")
images_processed = 0
grid_recons = None
with torch.no_grad():
for batch_idx, batch in enumerate(self.display_loader):
output, target = self.predict(batch)
target = transforms.center_crop_or_pad(target,
(self.args.resolution_height, self.args.resolution_width))
output = transforms.center_crop_or_pad(output,
(self.args.resolution_height, self.args.resolution_width))
if batch_idx == self.example_idx:
logging.info(f"output std: {output.std()} target std {target.std()}")
logging.info(f"output min: {output.min()} target min {target.min()}")
logging.info(f"output max: {output.max()} target max {target.max()}")
logging.debug(f"output (0.1, 1, 10, 90, 99, 99.9): {self.quantiles(output)}")
logging.debug(f"target (0.1, 1, 10, 90, 99, 99.9): {self.quantiles(target)}")
if grid_recons is None:
grid_recons = torch.zeros(grid_size, output.shape[1],
output.shape[2], output.shape[3]).to(self.device)
grid_images = torch.zeros_like(grid_recons)
grid_iffts = torch.zeros_like(grid_recons)
for j in range(output.shape[0]):
if images_processed >= grid_size:
break
grid_recons[images_processed, ...] = output.data[j, ...].float()
grid_images[images_processed, ...] = target.data[j, ...].float()
if self.args.display_ifft:
masked_kspace = batch['input']
ifft_abs = transforms.complex_abs(transforms.ifft2(masked_kspace)).squeeze(0)
masked_image = transforms.root_sum_of_squares(ifft_abs).unsqueeze(0)
masked_image = transforms.center_crop_or_pad(masked_image,
(self.args.resolution_height, self.args.resolution_width))
grid_iffts[images_processed, ...] = masked_image.data[j, ...].float()
images_processed += 1
logging.debug(f"Copying visual images to cpu")
sys.stdout.flush()
grid_recons = grid_recons.cpu()
grid_images = grid_images.cpu()
grid_errors = torch.abs(grid_recons - grid_images)
if self.args.rank == 0: # Only master task does visual
self.save_images(grid_images, 'Target', epoch)
self.save_images(grid_recons, 'Reconstruction', epoch)
self.save_images(grid_errors, 'Error', epoch)
logging.debug(f"Sent images to tensorboard and saved.")
sys.stdout.flush()
if self.args.display_ifft and self.args.rank == 0:
grid_iffts = grid_iffts.cpu()
self.save_images(grid_iffts, 'Ifft', epoch)
image_dir = self.exp_dir / "grids"
image_dir.mkdir(exist_ok=True)
image_blocks = []
losses = {'NMSE': [], 'SSIM': [], 'MSE': []}
for i in range(images_processed):
gtnp = grid_images[i].cpu().numpy()
prednp = grid_recons[i].cpu().numpy()
losses['NMSE'].append(evaluate.nmse(gtnp, prednp))
losses['SSIM'].append(evaluate.ssim(gtnp, prednp))
losses['MSE'].append(evaluate.mse(gtnp, prednp))
gt = grid_images[i]
shift = torch.min(gt)
scale = torch.max(gt - shift)
image_blocks.append((
(grid_images[i] - shift) / scale,
(grid_recons[i] - shift) / scale,
0.5 + 4 * (grid_errors[i] / scale)) +
(((grid_iffts[i] - shift) / scale, ) if self.args.display_ifft else ())
)
grid_pil = image_grid.grid(image_blocks, losses=losses, runinfo=self.runinfo)
grid_path = image_dir / f"epoch{epoch:03}.png"
if self.args.rank == 0:
grid_pil.save(grid_path, format="PNG")
logging.info(f"Saved image grid to {grid_path.resolve()}")
sys.stdout.flush()