in demucs/solver.py [0:0]
def train(self):
    """Run (or resume) the full training loop.

    Metrics of the epochs already stored in ``self.history`` are first
    replayed to the log. Training then continues from epoch
    ``len(self.history)`` up to ``self.args.epochs``. For each epoch:

    * train the model for one epoch;
    * cross-validate the main model and every EMA, keeping the best one
      according to ``self.args.test.metric``;
    * update ``self.best_state`` when a new best is reached (or always when
      ``self.args.dset.train_valid`` is set);
    * evaluate on the test set every ``test.every`` epochs and on the last
      epoch;
    * push the metrics to ``self.link`` and, on rank 0, write a checkpoint.
    """
    # Optimizing the model
    self._replay_history()
    for epoch in range(len(self.history), self.args.epochs):
        # Train one epoch
        self.model.train()  # Turn on BatchNorm & Dropout
        metrics = {}
        logger.info('-' * 70)
        logger.info("Training...")
        metrics['train'] = self._run_one_epoch(epoch)
        formatted = self._format_train(metrics['train'])
        logger.info(
            bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

        # Cross validation (main model + EMAs, best one selected)
        metrics['valid'], valid_loss, best_loss, state = self._cross_validate(epoch)

        # Save the best model
        if valid_loss == best_loss or self.args.dset.train_valid:
            logger.info(bold('New best valid loss %.4f'), valid_loss)
            self.best_state = states.copy_state(state)
            self.best_changed = True

        # Eval model every `test.every` epoch or on last epoch. A falsy
        # `test.every` (0 / None) disables periodic evaluation instead of
        # crashing on the modulo; the last epoch is still evaluated.
        every = self.args.test.every
        should_eval = bool(every) and (epoch + 1) % every == 0
        is_last = epoch == self.args.epochs - 1
        # Tries to detect divergence in a reliable way and finish job
        # not to waste compute.
        if self._is_diverging(epoch, metrics['valid']):
            logger.warning("Finishing training early because valid loss is too high.")
            is_last = True
        if should_eval or is_last:
            metrics['test'] = self._evaluate_test(epoch, is_last)
        self.link.push_metrics(metrics)
        if distrib.rank == 0:
            # Save model each epoch
            self._serialize(epoch)
            logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
        if is_last:
            break


def _replay_history(self):
    """Log the train/valid/test summaries of epochs already in ``self.history``."""
    if self.history:
        logger.info("Replaying metrics from previous run")
    for epoch, metrics in enumerate(self.history):
        formatted = self._format_train(metrics['train'])
        logger.info(
            bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
        formatted = self._format_train(metrics['valid'])
        logger.info(
            bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
        if 'test' in metrics:
            formatted = self._format_test(metrics['test'])
            if formatted:
                logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))


def _cross_validate(self, epoch):
    """Validate the main model and each EMA, keeping the best-scoring one.

    Args:
        epoch: index of the current epoch (forwarded to ``_run_one_epoch``
            and used in the log line).

    Returns:
        ``(valid, valid_loss, best_loss, state)`` where

        * ``valid`` maps ``'main'`` and each ``'ema_{kind}_{k}'`` to that
          model's metrics. The winner's metrics are merged in at the top
          level. The winner's name is stored under ``'bname'``, and the best
          ``test.metric`` value over the run history under ``'best'``. When
          the SVD penalty is enabled, ``'penalty'`` is also set.
        * ``valid_loss`` is the winner's ``test.metric`` value.
        * ``best_loss`` is the best value so far, including this epoch.
        * ``state`` is the winner's state: a copy of the main model's
          ``state_dict()``, or ``ema.state`` for an EMA (not copied here).
    """
    logger.info('-' * 70)
    logger.info('Cross validation...')
    self.model.eval()  # Turn off Batchnorm & Dropout
    key = self.args.test.metric
    # 'nsdr*' metrics are higher-is-better; every other metric is a loss.
    higher_is_better = key.startswith('nsdr')
    with torch.no_grad():
        main_valid = self._run_one_epoch(epoch, train=False)
        best_valid = main_valid
        best_name = 'main'
        state = states.copy_state(self.model.state_dict())
        valid = {'main': main_valid}
        for kind, emas in self.emas.items():
            for k, ema in enumerate(emas):
                # Temporarily run validation with the EMA weights.
                with ema.swap():
                    ema_valid = self._run_one_epoch(epoch, train=False)
                name = f'ema_{kind}_{k}'
                valid[name] = ema_valid
                new, old = ema_valid[key], best_valid[key]
                improved = new > old if higher_is_better else new < old
                if improved:
                    best_valid = ema_valid
                    state = ema.state
                    best_name = name
        valid.update(best_valid)
        valid['bname'] = best_name

    valid_loss = valid[key]
    # `self.link.history` does not include the current epoch yet.
    mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
    best_loss = max(mets) if higher_is_better else min(mets)
    valid['best'] = best_loss
    if self.args.svd.penalty > 0:
        kw = dict(self.args.svd)
        kw.pop('penalty')
        with torch.no_grad():
            penalty = svd_penalty(self.model, exact=True, **kw)
        valid['penalty'] = penalty
    formatted = self._format_train(valid)
    logger.info(
        bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
    return valid, valid_loss, best_loss, state


def _is_diverging(self, epoch, valid):
    """Return True when the main model's validation 'reco' looks diverged.

    The heuristic only applies to the L1 loss and uses hard-coded
    (epoch, threshold) pairs. The loss type is checked first, so the
    'reco' metric is only looked up when the heuristic is actually used.
    """
    if self.args.optim.loss != 'l1':
        return False
    reco = valid['main']['reco']
    return (epoch >= 180 and reco > 0.18) or (epoch >= 100 and reco > 0.25)


def _evaluate_test(self, epoch, is_last):
    """Evaluate on the test set, log a summary and return the test metrics.

    SDR is only computed when ``self.args.test.sdr`` is set and this is the
    last epoch. When ``self.args.test.best`` is set but no best state has been
    recorded yet, the current model is tested instead, with a warning.
    """
    logger.info('-' * 70)
    logger.info('Evaluating on the test set...')
    # We switch to the best known model for testing
    use_best = self.args.test.best
    if use_best and self.best_state is None:
        logger.warning("No best model state recorded yet, testing the current model instead.")
        use_best = False
    if use_best:
        state = self.best_state
    else:
        state = states.copy_state(self.model.state_dict())
    compute_sdr = self.args.test.sdr and is_last
    with states.swap_state(self.model, state):
        with torch.no_grad():
            test = evaluate(self, compute_sdr=compute_sdr)
    formatted = self._format_test(test)
    logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
    return test