in jukebox/train.py [0:0]
def evaluate(model, orig_model, logger, metrics, data_processor, hps):
    """Run one pass over the test loader and return averaged test metrics.

    Both ``model`` and ``orig_model`` are put in eval mode and every batch is
    processed under ``t.no_grad()``. For each batch the model returns
    ``(x_out, loss, metrics_dict)``. The per-batch values are converted to
    Python scalars with ``.item()`` and fed into ``metrics`` under
    ``"test_<name>"`` keys. The first batch's inputs and outputs are
    additionally passed to ``log_inputs``.

    Args:
        model: Called as ``model(x, **kwargs)``; must return
            ``(x_out, loss, metrics_dict)``.
        orig_model: Passed to ``log_inputs`` for the first batch.
            Presumably the unwrapped counterpart of ``model`` — confirm
            against callers.
        logger: Provides ``get_range``, ``set_postfix``, ``add_scalar`` and
            ``close_range``.
        metrics: Accumulator providing ``update(name, val, n)`` and
            ``avg(name)``.
        data_processor: Exposes ``test_loader``. Each item is either ``x``
            or an ``(x, y)`` tuple/list.
        hps: Hyperparameters. Reads ``prior``, ``fp16`` and ``loss_fn``,
            and is forwarded to ``audio_preprocess``, the model (non-prior)
            and ``log_inputs``.

    Returns:
        dict mapping each metric name produced by the model (plus
        ``"loss"``) to ``metrics.avg("test_<name>")``. Empty if the test
        loader yields no batches.
    """
    model.eval()
    orig_model.eval()
    # Short progress-bar labels -> metric keys shown via logger.set_postfix.
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd")
    else:
        _print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")

    # Bound up front so an empty test loader returns {} instead of raising
    # UnboundLocalError in the logging/return code after the loop.
    _metrics = {}
    with t.no_grad():
        for i, x in logger.get_range(data_processor.test_loader):
            if isinstance(x, (tuple, list)):
                x, y = x
            else:
                y = None

            x = x.to('cuda', non_blocking=True)
            if y is not None:
                y = y.to('cuda', non_blocking=True)

            x_in = x = audio_preprocess(x, hps)
            # Only the first batch is decoded (prior) and sent to log_inputs.
            log_input_output = (i == 0)

            if hps.prior:
                forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
            else:
                forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

            x_out, loss, _metrics = model(x, **forw_kwargs)

            # Logging: convert to Python scalars with .item() so no tensor /
            # graph references are kept alive across batches.
            for key, val in _metrics.items():
                _metrics[key] = val.item()
            _metrics["loss"] = loss.item()

            # Average and log. x.shape[0] is passed as the count (presumably
            # the batch size). Entries are replaced by metrics.update's
            # return value.
            for key, val in _metrics.items():
                _metrics[key] = metrics.update(f"test_{key}", val, x.shape[0])

            # Already inside t.no_grad(), so no nested context is needed.
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

            logger.set_postfix(
                **{print_key: _metrics[key] for print_key, key in _print_keys.items()}
            )

    for key in _metrics:
        logger.add_scalar(f"test_{key}", metrics.avg(f"test_{key}"))

    logger.close_range()
    return {key: metrics.avg(f"test_{key}") for key in _metrics}