in banding_removal/fastmri/training_loop_mixin.py [0:0]
def compute_stats(self, epoch, loader, setname):
    """Evaluate the model and return a dict of metrics.

    This is kept separate from stats mainly for distributed support.
    """
    args = self.args
    self.model.eval()
    # Note: evaluation always runs over self.dev_loader; the `loader` and
    # `setname` arguments are currently unused here.
    ndevbatches = len(self.dev_loader)
    logging.info(f"Evaluating {ndevbatches} batches ...")

    recons, gts = defaultdict(list), defaultdict(list)
    acquisition_machine_by_fname = dict()
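    # Run the model over the dev set without gradients, accumulating per-slice
    # reconstructions and ground truths keyed by volume filename.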
    with torch.no_grad():
        for batch_idx, batch in enumerate(self.dev_loader):
            progress = epoch + batch_idx / ndevbatches
            logging_epoch = batch_idx % args.log_interval == 0
            logging_epoch_info = batch_idx % (2 * args.log_interval) == 0
            log = logging.info if logging_epoch_info else logging.debug

            self.start_of_test_batch_hook(progress, logging_epoch)

            batch = self.preprocess_data(batch)
            output, target = self.predict(batch)
            output = self.unnorm(output, batch)
            target = self.unnorm(target, batch)
            fname, slice = batch.fname, batch.slice
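            # Move each slice in the batch to the CPU and record it, together with
            # the acquisition/system type it came from, under its volume filename.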
            for i in range(output.shape[0]):
                slice_cpu = slice[i].item()
                recons[fname[i]].append((slice_cpu, output[i].float().cpu().numpy()))
                gts[fname[i]].append((slice_cpu, target[i].float().cpu().numpy()))

                acquisition_type = batch.attrs_dict['acquisition'][i]
                machine_type = batch.attrs_dict['system'][i]
                acquisition_machine_by_fname[fname[i]] = machine_type + '_' + acquisition_type
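            # Periodically log progress along with GPU and host memory usage.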
            if logging_epoch or batch_idx == ndevbatches - 1:
                gpu_memory_gb = torch.cuda.memory_allocated() / 1000000000
                host_memory_gb = utils.host_memory_usage_in_gb()
                log(f"Evaluated {batch_idx+1} of {ndevbatches} "
                    f"(GPU Mem: {gpu_memory_gb:2.3f}gb Host Mem: {host_memory_gb:2.3f}gb)")
                sys.stdout.flush()

            if self.args.debug_epoch_stats:
                break
            del output, target, batch

    logging.debug("Finished evaluating")
    self.end_of_test_epoch_hook()
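    # Assemble each volume by sorting its slices by slice index and stacking
    # them into a single array per filename.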
    recons = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in recons.items()
    }
    gts = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in gts.items()
    }
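    # Compute per-volume metrics, freeing each volume as soon as it is scored
    # to keep host memory bounded.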
    nmse, psnr, ssims = [], [], []
    ssim_for_acquisition_machine = defaultdict(list)
    recon_keys = list(recons.keys())  # list() already copies, so deleting entries below is safe
    for fname in recon_keys:
        pred_or, gt_or = recons[fname].squeeze(1), gts[fname].squeeze(1)
        pred, gt = transforms.center_crop_to_smallest(pred_or, gt_or)
        del pred_or, gt_or

        ssim = evaluate.ssim(gt, pred)
        acquisition_machine = acquisition_machine_by_fname[fname]
        ssim_for_acquisition_machine[acquisition_machine].append(ssim)
        ssims.append(ssim)
        nmse.append(evaluate.nmse(gt, pred))
        psnr.append(evaluate.psnr(gt, pred))
        del gt, pred
        del recons[fname], gts[fname]
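    # Fall back to zero-valued metrics when this process evaluated no volumes
    # (possible under distributed evaluation).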
    if len(nmse) == 0:
        nmse.append(0)
        ssims.append(0)
        psnr.append(0)

    min_vol_ssim = np.argmin(ssims)
    # Guard against recon_keys being empty when the zero fallback above was used.
    min_vol = str(recon_keys[min_vol_ssim]) if recon_keys else "n/a"
    logging.info(f"Min vol ssims: {min_vol}")
    sys.stdout.flush()

    del recons, gts
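    # Average SSIM per (system, acquisition) pairing; keys with no evaluated
    # volumes keep their default of 0.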
    acquisition_machine_losses = dict.fromkeys(self.dev_data.system_acquisitions, 0)
    for key, value in ssim_for_acquisition_machine.items():
        acquisition_machine_losses[key] = np.mean(value)

    losses = {'NMSE': np.mean(nmse),
              'PSNR': np.mean(psnr),
              'SSIM': np.mean(ssims),
              'SSIM_var': np.var(ssims),
              'SSIM_min': np.min(ssims),
              **acquisition_machine_losses}

    return losses