in jukebox/train.py [0:0]
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")
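    # Per-batch loop: move data to GPU, preprocess audio, forward, backward with optional fp16
    # loss scaling, a conditional optimizer step (skipped on overflow), then metric logging plus
    # periodic checkpointing and sampling.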
    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None
        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)
        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)
        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                        scalar=scalar, fp16=hps.fp16, logger=logger)

        # Skip step if the loss or gradients overflowed (fp16), or if the gradient norm exceeds hps.ignore_grad_norm
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue
        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)
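        # finished_training is set once the LR schedule has decayed to zero; it triggers a final
        # checkpoint and sample below, then a synchronized exit.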
        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])
        # Save checkpoint
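        # EMA weights are swapped in so the averaged parameters are the ones saved, then swapped back.
        # Only ranks with rank % 8 == 0 (presumably one per node) write to disk.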
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()
        # Sample
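        # Draws samples from the prior roughly every 12000 iterations (the offset range appears to
        # account for gradient accumulation via hps.iters_before_update) and once training finishes.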
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)
        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}