in fastmri/pl_modules/mri_module.py [0:0]
def validation_epoch_end(self, val_logs):
    """Aggregate per-batch validation outputs into epoch-level metrics.

    Each entry of ``val_logs`` is a dict produced by a validation step,
    containing a scalar ``val_loss`` tensor and per-volume dicts
    (``mse_vals``, ``target_norms``, ``ssim_vals`` keyed by filename then
    slice index, and ``max_vals`` keyed by filename). Per-slice values are
    merged across batches, averaged per volume, summed across volumes, then
    reduced across DDP ranks via the ``self.NMSE``/``SSIM``/``PSNR``/
    ``TotExamples``/``ValLoss``/``TotSliceExamples`` metric objects.
    Results are logged as ``validation_loss`` and ``val_metrics/*``.

    Args:
        val_logs: List of per-step output dicts (see above for schema).
    """
    losses = []
    mse_vals = defaultdict(dict)
    target_norms = defaultdict(dict)
    ssim_vals = defaultdict(dict)
    max_vals = dict()

    # Merge with dict.update so duplicate slices (e.g. from overlapping
    # samplers) are counted once — the last occurrence wins.
    for val_log in val_logs:
        losses.append(val_log["val_loss"].view(-1))
        for k in val_log["mse_vals"]:
            mse_vals[k].update(val_log["mse_vals"][k])
        for k in val_log["target_norms"]:
            target_norms[k].update(val_log["target_norms"][k])
        for k in val_log["ssim_vals"]:
            ssim_vals[k].update(val_log["ssim_vals"][k])
        for k in val_log["max_vals"]:
            max_vals[k] = val_log["max_vals"][k]

    # Internal invariant: every metric dict must cover the same volumes.
    assert (
        mse_vals.keys()
        == target_norms.keys()
        == ssim_vals.keys()
        == max_vals.keys()
    )

    # Average per-slice values within each volume, then sum over volumes;
    # the per-example division happens after the DDP reduction below.
    metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
    local_examples = len(mse_vals)  # one example per volume (filename)
    for fname in mse_vals:
        mse_val = torch.mean(
            torch.cat([v.view(-1) for v in mse_vals[fname].values()])
        )
        target_norm = torch.mean(
            torch.cat([v.view(-1) for v in target_norms[fname].values()])
        )
        metrics["nmse"] = metrics["nmse"] + mse_val / target_norm
        # PSNR = 20*log10(max) - 10*log10(MSE), computed on mse_val's
        # device/dtype so DDP tensors stay consistent.
        metrics["psnr"] = (
            metrics["psnr"]
            + 20
            * torch.log10(
                torch.tensor(
                    max_vals[fname], dtype=mse_val.dtype, device=mse_val.device
                )
            )
            - 10 * torch.log10(mse_val)
        )
        metrics["ssim"] = metrics["ssim"] + torch.mean(
            torch.cat([v.view(-1) for v in ssim_vals[fname].values()])
        )

    # Reduce sums across DDP ranks via the metric objects (sum-reduction).
    metrics["nmse"] = self.NMSE(metrics["nmse"])
    metrics["ssim"] = self.SSIM(metrics["ssim"])
    metrics["psnr"] = self.PSNR(metrics["psnr"])
    tot_examples = self.TotExamples(torch.tensor(local_examples))
    val_loss = self.ValLoss(torch.sum(torch.cat(losses)))
    tot_slice_examples = self.TotSliceExamples(
        torch.tensor(len(losses), dtype=torch.float)
    )

    self.log("validation_loss", val_loss / tot_slice_examples, prog_bar=True)
    for metric, value in metrics.items():
        self.log(f"val_metrics/{metric}", value / tot_examples)