in fastmri/pl_modules/mri_module.py [0:0]
def validation_step_end(self, val_logs):
    """Check, log and score the outputs of one validation step.

    Validates the dict returned by ``validation_step``, sends a fixed random
    subset of reconstructions to ``self.log_image``, and computes per-slice
    MSE, target norm (MSE against zeros) and SSIM values.

    Args:
        val_logs: Dict returned by ``validation_step``. Must contain the keys
            ``batch_idx``, ``fname``, ``slice_num``, ``max_value``,
            ``output``, ``target`` and ``val_loss``. ``output`` and
            ``target`` must be 2D or 3D tensors; 2D tensors are unsqueezed to
            3D and written back into ``val_logs`` (the dict is mutated).

    Returns:
        Dict with ``val_loss`` passed through unchanged, ``mse_vals``,
        ``target_norms`` and ``ssim_vals`` (each ``fname -> {slice_num:
        1-element tensor}``) and ``max_vals`` (``fname -> max value`` as
        returned by ``.cpu().numpy()``).

    Raises:
        RuntimeError: If a required key is missing, or if ``output`` or
            ``target`` is neither 2D nor 3D.
    """
    # check inputs
    for k in (
        "batch_idx",
        "fname",
        "slice_num",
        "max_value",
        "output",
        "target",
        "val_loss",
    ):
        if k not in val_logs:
            raise RuntimeError(
                f"Expected key {k} in dict returned by validation_step."
            )

    # Promote a single 2D image to a batch of one; anything else must
    # already be 3D. The message names the offending key (previously a
    # bad target was reported as a bad "output").
    for k in ("output", "target"):
        if val_logs[k].ndim == 2:
            val_logs[k] = val_logs[k].unsqueeze(0)
        elif val_logs[k].ndim != 3:
            raise RuntimeError(f"Unexpected {k} size from validation_step.")

    def _scale_to_unit_peak(image):
        # Divide by the peak value for display. An all-zero image (e.g. the
        # error map when output == target) is returned as-is instead of
        # becoming 0/0 = NaN.
        peak = image.max()
        return image if peak == 0 else image / peak

    # pick a set of images to log if we don't have one already; once chosen
    # the indices are kept for subsequent calls
    if self.val_log_indices is None:
        self.val_log_indices = list(
            np.random.permutation(len(self.trainer.val_dataloaders[0]))[
                : self.num_log_images
            ]
        )

    # log images
    if isinstance(val_logs["batch_idx"], int):
        batch_indices = [val_logs["batch_idx"]]
    else:
        batch_indices = val_logs["batch_idx"]
    for i, batch_idx in enumerate(batch_indices):
        if batch_idx in self.val_log_indices:
            key = f"val_images_idx_{batch_idx}"
            target = val_logs["target"][i].unsqueeze(0)
            output = val_logs["output"][i].unsqueeze(0)
            # error is computed on the un-normalized images
            error = torch.abs(target - output)
            self.log_image(f"{key}/target", _scale_to_unit_peak(target))
            self.log_image(f"{key}/reconstruction", _scale_to_unit_peak(output))
            self.log_image(f"{key}/error", _scale_to_unit_peak(error))

    # compute evaluation metrics, keyed by file name then slice number
    mse_vals = defaultdict(dict)
    target_norms = defaultdict(dict)
    ssim_vals = defaultdict(dict)
    max_vals = {}
    for i, fname in enumerate(val_logs["fname"]):
        slice_num = int(val_logs["slice_num"][i].cpu())
        maxval = val_logs["max_value"][i].cpu().numpy()
        output = val_logs["output"][i].cpu().numpy()
        target = val_logs["target"][i].cpu().numpy()

        mse_vals[fname][slice_num] = torch.tensor(
            evaluate.mse(target, output)
        ).view(1)
        target_norms[fname][slice_num] = torch.tensor(
            evaluate.mse(target, np.zeros_like(target))
        ).view(1)
        # a leading axis is added before calling evaluate.ssim
        ssim_vals[fname][slice_num] = torch.tensor(
            evaluate.ssim(target[None, ...], output[None, ...], maxval=maxval)
        ).view(1)
        max_vals[fname] = maxval

    return {
        "val_loss": val_logs["val_loss"],
        "mse_vals": dict(mse_vals),
        "target_norms": dict(target_norms),
        "ssim_vals": dict(ssim_vals),
        "max_vals": max_vals,
    }