def compute_stats()

in banding_removal/fastmri/training_loop_mixin.py [0:0]


    def compute_stats(self, epoch, loader, setname):
        """Evaluate the model on the dev set and return aggregate metrics.

        This is separate from stats mainly for distributed support.

        Every batch of ``self.dev_loader`` is passed through
        ``preprocess_data``, ``predict`` and ``unnorm`` under ``torch.no_grad``.
        Per-slice outputs and targets are grouped by file name and stacked
        into volumes ordered by slice index. Each volume pair is then
        center-cropped to the smaller size and scored with NMSE, PSNR and SSIM.

        Args:
            epoch: Current epoch; combined with the batch fraction to form the
                ``progress`` value given to ``start_of_test_batch_hook``.
            loader: Unused -- evaluation always iterates ``self.dev_loader``.
            setname: Unused.

        Returns:
            dict with ``'NMSE'``, ``'PSNR'``, ``'SSIM'`` (means over volumes),
            ``'SSIM_var'`` and ``'SSIM_min'``. It also contains one mean-SSIM
            entry per ``system + '_' + acquisition`` group. Groups listed in
            ``self.dev_data.system_acquisitions`` default to 0 when they had
            no volumes. If no volume was evaluated at all, the aggregate
            metrics are reported as 0.
        """
        args = self.args
        self.model.eval()
        ndevbatches = len(self.dev_loader)
        logging.info(f"Evaluating {ndevbatches} batches ...")

        recons, gts = defaultdict(list), defaultdict(list)
        acquisition_machine_by_fname = dict()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.dev_loader):
                progress = epoch + batch_idx / ndevbatches
                logging_epoch = batch_idx % args.log_interval == 0
                logging_epoch_info = batch_idx % (2 * args.log_interval) == 0
                log = logging.info if logging_epoch_info else logging.debug

                self.start_of_test_batch_hook(progress, logging_epoch)

                batch = self.preprocess_data(batch)
                output, target = self.predict(batch)
                output = self.unnorm(output, batch)
                target = self.unnorm(target, batch)
                fnames, slice_idxs = batch.fname, batch.slice

                for i in range(output.shape[0]):
                    fname = fnames[i]
                    slice_num = slice_idxs[i].item()
                    recons[fname].append((slice_num, output[i].float().cpu().numpy()))
                    gts[fname].append((slice_num, target[i].float().cpu().numpy()))

                    acquisition_type = batch.attrs_dict['acquisition'][i]
                    machine_type = batch.attrs_dict['system'][i]
                    acquisition_machine_by_fname[fname] = machine_type + '_' + acquisition_type

                if logging_epoch or batch_idx == ndevbatches - 1:
                    gpu_memory_gb = torch.cuda.memory_allocated() / 1000000000
                    host_memory_gb = utils.host_memory_usage_in_gb()
                    log(f"Evaluated {batch_idx+1} of {ndevbatches} (GPU Mem: {gpu_memory_gb:2.3f}gb Host Mem: {host_memory_gb:2.3f}gb)")
                    sys.stdout.flush()

                if self.args.debug_epoch_stats:
                    break
                # Drop references to the per-batch tensors before the next batch.
                del output, target, batch

            logging.debug("Finished evaluating")
            self.end_of_test_epoch_hook()

            # Sort on the slice index only. Sorting the raw (slice, ndarray)
            # tuples would fall back to comparing arrays whenever two entries
            # share a slice index (possible if the loader repeats samples --
            # e.g. padding in a distributed sampler), which raises ValueError.
            recons = {
                fname: np.stack([pred for _, pred in sorted(slice_preds, key=lambda item: item[0])])
                for fname, slice_preds in recons.items()
            }
            gts = {
                fname: np.stack([gt for _, gt in sorted(slice_gts, key=lambda item: item[0])])
                for fname, slice_gts in gts.items()
            }

            nmse, psnr, ssims = [], [], []
            ssim_for_acquisition_machine = defaultdict(list)
            recon_keys = list(recons)
            for fname in recon_keys:
                # pop() releases each stacked volume as soon as it has been scored.
                pred, gt = transforms.center_crop_to_smallest(
                    recons.pop(fname).squeeze(1), gts.pop(fname).squeeze(1))

                ssim = evaluate.ssim(gt, pred)
                ssim_for_acquisition_machine[acquisition_machine_by_fname[fname]].append(ssim)
                ssims.append(ssim)
                nmse.append(evaluate.nmse(gt, pred))
                psnr.append(evaluate.psnr(gt, pred))
                del gt, pred

            if recon_keys:
                min_vol = str(recon_keys[int(np.argmin(ssims))])
                logging.info(f"Min vol ssims: {min_vol}")
            else:
                # Nothing was evaluated: report zeros rather than failing on
                # mean/min of empty lists (and on indexing an empty key list).
                nmse.append(0)
                ssims.append(0)
                psnr.append(0)
            sys.stdout.flush()

            del recons, gts

            acquisition_machine_losses = dict.fromkeys(self.dev_data.system_acquisitions, 0)
            for key, value in ssim_for_acquisition_machine.items():
                acquisition_machine_losses[key] = np.mean(value)

            losses = {'NMSE': np.mean(nmse),
                      'PSNR': np.mean(psnr),
                      'SSIM': np.mean(ssims),
                      'SSIM_var': np.var(ssims),
                      'SSIM_min': np.min(ssims),
                      **acquisition_machine_losses}

        return losses