in scripts/image_nll.py [0:0]
def _mean_across_ranks(tensor, world_size):
    """Return the average of ``tensor`` over all ranks.

    Each rank divides its local value by ``world_size`` and the results are
    summed in place with ``dist.all_reduce`` (default op: SUM), so every rank
    ends up holding the same cross-rank mean.
    """
    tensor = tensor / world_size  # new tensor; the caller's input is untouched
    dist.all_reduce(tensor)
    return tensor


def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
    """Estimate bits-per-dim of ``model`` on batches drawn from ``data``.

    Batches are pulled from ``data`` until at least ``num_samples`` samples
    have been evaluated across all ranks (the final round may overshoot).
    For every batch, ``diffusion.calc_bpd_loop`` is run. Its ``"vb"``,
    ``"mse"`` and ``"xstart_mse"`` terms are averaged over the batch
    dimension, and its ``"total_bpd"`` is averaged over all elements. Both
    are then averaged across ranks. A running mean of total bpd is logged
    after each batch.

    When finished, rank 0 writes one ``{name}_terms.npz`` file per metric
    into ``logger.get_dir()``. Each file holds the mean of that metric's
    per-batch terms. All ranks then synchronize on ``dist.barrier()``.

    Args:
        model: model forwarded to ``diffusion.calc_bpd_loop``.
        diffusion: object providing ``calc_bpd_loop``.
        data: iterator yielding ``(batch, model_kwargs)`` pairs, where
            ``model_kwargs`` maps names to tensors.
        num_samples: minimum total number of samples (summed over ranks) to
            evaluate. If it is <= 0, no batches are evaluated and no term
            files are written.
        clip_denoised: forwarded to ``calc_bpd_loop``.

    NOTE(review): cross-rank averaging divides by the world size, which
    weights every rank equally. This is only an exact mean if all ranks see
    the same batch size. The logged running bpd likewise weights batches
    equally. Confirm the data loader yields fixed-size batches.
    """
    # The world size of the default process group is fixed once initialized.
    world_size = dist.get_world_size()
    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        device = dist_util.dev()
        batch = batch.to(device)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        # Average each term over dim 0 (the batch), then across ranks.
        # Presumably these are per-timestep terms shaped (batch, T) -- TODO
        # confirm against calc_bpd_loop.
        for key, term_list in all_metrics.items():
            batch_mean = minibatch_metrics[key].mean(dim=0)
            terms = _mean_across_ranks(batch_mean, world_size)
            term_list.append(terms.detach().cpu().numpy())

        batch_bpd = minibatch_metrics["total_bpd"].mean()
        total_bpd = _mean_across_ranks(batch_bpd, world_size)
        all_bpd.append(total_bpd.item())
        num_complete += world_size * batch.shape[0]

        logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")

    if dist.get_rank() == 0:
        for name, terms in all_metrics.items():
            if not terms:
                # np.stack raises ValueError on an empty list (the loop never
                # ran because num_samples <= 0). Crashing here would leave the
                # other ranks blocked in dist.barrier() below, so skip instead.
                logger.log(f"no {name} terms collected; skipping save")
                continue
            out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            # Passed positionally, so it is stored under the key "arr_0".
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    dist.barrier()
    logger.log("evaluation complete")