in fairnr/models/fairnr_model.py [0:0]
def add_eval_scores(self, logging_output, sample, output, criterion,
                    scores=('ssim', 'psnr', 'lpips'), outdir=None):
    """Compute per-view evaluation metrics and optionally dump visualisations.

    For every view ``[s, v]`` of the batch, the predicted and target colors
    are converted with ``recover_image`` and scored with the metrics named in
    ``scores``. The mean over all views of each metric is written into
    ``logging_output``. A key is written only if that metric produced at
    least one value.

    Args:
        logging_output (dict): updated in place with ``'ssim_loss'``,
            ``'psnr_loss'``, ``'lpips_loss'`` and/or ``'rmses_loss'``.
        sample (dict): batch inputs. Reads ``'colors'``, ``'size'`` and, if
            present, ``'depths'``. When ``outdir`` is set it also reads
            ``'id'`` and ``'view'``. If ``output`` has ``'depths'``, it also
            reads ``'ray_start'``, ``'ray_dir'`` and ``'extrinsics'``.
        output (dict): model outputs. Reads ``'colors'`` and optionally
            ``'depths'``, ``'featn2'`` and ``'voxel'``.
        criterion: LPIPS is computed only when it has an ``lpips`` attribute.
        scores (iterable of str): which of ``'ssim'``, ``'psnr'`` and
            ``'lpips'`` to compute. It is a tuple by default so the shared
            default object can never be mutated between calls.
        outdir (str or None): if given, PNG images for every view are
            written into this directory.
    """
    predicts, targets = output['colors'], sample['colors']
    ssims, psnrs, lpips_scores, rmses = [], [], [], []
    min_color = float(self.args.min_color)
    # Depth RMSE needs both the ground truth and the prediction.
    has_depths = 'depths' in sample and 'depths' in output

    def imsave(filename, image):
        # Clip before the uint8 cast: out-of-range floats would otherwise wrap
        # around (e.g. 1.01 * 255 -> 1) instead of saturating at 0 / 255.
        image = np.clip(image, 0, 1)
        imageio.imsave(os.path.join(outdir, filename), (image * 255).astype('uint8'))

    for s in range(predicts.size(0)):
        for v in range(predicts.size(1)):
            width = int(sample['size'][s, v][1])
            p = recover_image(predicts[s, v], width=width, min_val=min_color)
            t = recover_image(targets[s, v], width=width, min_val=min_color)
            pn, tn = p.numpy(), t.numpy()

            if 'ssim' in scores:
                # NOTE(review): `multichannel` is deprecated in newer
                # scikit-image releases in favour of `channel_axis`; kept here
                # for compatibility with the version this project pins.
                ssims.append(skimage.metrics.structural_similarity(
                    pn, tn, multichannel=True, data_range=1))
            if 'psnr' in scores:
                psnrs.append(skimage.metrics.peak_signal_noise_ratio(pn, tn, data_range=1))
            if 'lpips' in scores and hasattr(criterion, 'lpips'):
                # Move to the model device only when LPIPS actually needs it.
                p_dev, t_dev = p.to(predicts.device), t.to(targets.device)
                # Presumably p/t are (H, W, C) in [0, 1] (TODO confirm): the
                # permute yields (1, C, H, W) and 2x - 1 rescales to [-1, 1].
                with torch.no_grad():
                    lpips_scores.append(criterion.lpips(
                        2 * p_dev.unsqueeze(-1).permute(3, 2, 0, 1) - 1,
                        2 * t_dev.unsqueeze(-1).permute(3, 2, 0, 1) - 1).item())
            if has_depths:
                # Per-view RMSE over pixels with a positive ground-truth depth.
                # Views with no such pixel are skipped; the mean of an empty
                # selection would be NaN and poison the averaged metric.
                true_depth = sample['depths'][s, v]
                valid = true_depth > 0
                if valid.any():
                    pred_depth = output['depths'][s, v]
                    rmses.append(torch.sqrt(
                        ((true_depth[valid] - pred_depth[valid]) ** 2).mean()).item())

            if outdir is not None:
                figname = '-{:03d}_{:03d}.png'.format(
                    int(sample['id'][s]), int(sample['view'][s, v]))
                imsave('output' + figname, pn)
                imsave('target' + figname, tn)
                if 'depths' in output:
                    normal = compute_normal_map(
                        sample['ray_start'][s, v].float(), sample['ray_dir'][s, v].float(),
                        output['depths'][s, v].float(),
                        sample['extrinsics'][s, v].float().inverse(), width=width)
                    imsave('normal' + figname,
                           recover_image(normal, min_val=-1, max_val=1, width=width).numpy())
                if 'featn2' in output:
                    imsave('featn2' + figname, output['featn2'][s, v].cpu().numpy())
                if 'voxel' in output:
                    imsave('voxel' + figname, output['voxel'][s, v].cpu().numpy())

    if ssims:
        logging_output['ssim_loss'] = np.mean(ssims)
    if psnrs:
        logging_output['psnr_loss'] = np.mean(psnrs)
    if lpips_scores:
        logging_output['lpips_loss'] = np.mean(lpips_scores)
    if rmses:
        logging_output['rmses_loss'] = np.mean(rmses)