in svoice/solver.py
def train(self):
    # Optimizing the model
    if self.history:
        logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
            logger.info(f"Epoch {epoch}: {info}")
    for epoch in range(len(self.history), self.epochs):
        # Train one epoch
        self.model.train()  # Turn on BatchNorm & Dropout
        start = time.time()
        logger.info('-' * 70)
        logger.info("Training...")
        train_loss = self._run_one_epoch(epoch)
        logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
                         f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))
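        # _run_one_epoch is expected to return the loss averaged over the
        # full loader; with cross_valid=True below, the same helper runs on
        # the validation set instead of the training set.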
        # Cross validation
        logger.info('-' * 70)
        logger.info('Cross validation...')
        self.model.eval()  # Turn off BatchNorm & Dropout
        with torch.no_grad():
            valid_loss = self._run_one_epoch(epoch, cross_valid=True)
        logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                         f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))
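        # eval() makes BatchNorm use its running statistics and disables
        # Dropout, while no_grad() skips building the autograd graph, so
        # validation is deterministic and lighter on memory.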
        # Learning rate scheduling
        if self.sched:
            if self.args.lr_sched == 'plateau':
                self.sched.step(valid_loss)
            else:
                self.sched.step()
            logger.info(
                f'Learning rate adjusted: '
                f'{self.optimizer.state_dict()["param_groups"][0]["lr"]:.5f}')
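        # This mirrors PyTorch's scheduler API: ReduceLROnPlateau (the
        # assumed 'plateau' option) must be given the monitored value, as in
        # sched.step(valid_loss), whereas step-based schedulers such as
        # StepLR advance unconditionally with sched.step().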
        best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
        metrics = {'train': train_loss,
                   'valid': valid_loss, 'best': best_loss}
        # Save the best model
        if valid_loss == best_loss or self.args.keep_last:
            logger.info(bold('New best valid loss %.4f'), valid_loss)
            self.best_state = copy_state(self.model.state_dict())
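        # copy_state is assumed to clone every tensor in the state dict;
        # state_dict() alone returns references, so without a copy the next
        # optimizer step would silently overwrite the saved best weights.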
        # Evaluate and separate samples every `eval_every` epochs,
        # and also on the last epoch
        if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
            # Evaluate on the test set
            logger.info('-' * 70)
            logger.info('Evaluating on the test set...')
            # We switch to the best known model for testing
            with swap_state(self.model, self.best_state):
                sisnr, pesq, stoi = evaluate(
                    self.args, self.model, self.tt_loader, self.args.sample_rate)
            metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})
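            # swap_state is expected to be a context manager that loads
            # best_state on entry and restores the current weights on exit,
            # so training continues from the latest epoch rather than from
            # the best checkpoint (see the sketch at the end of this file).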
            # Separate some samples
            logger.info('Separate and save samples...')
            separate(self.args, self.model, self.samples_dir)
        self.history.append(metrics)
        info = " | ".join(
            f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
        logger.info('-' * 70)
        logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))
        if distrib.rank == 0:
            # Only the main worker persists metrics and checkpoints.
            with open(self.history_file, "w") as f:
                json.dump(self.history, f, indent=2)
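            # The resulting history file is a JSON list of per-epoch dicts,
            # e.g. (illustrative values):
            #   [{"train": 0.1234, "valid": 0.1100, "best": 0.1100}, ...]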
            # Save the model each epoch
            if self.checkpoint:
                self._serialize(self.checkpoint)
                logger.debug("Checkpoint saved to %s",
                             self.checkpoint.resolve())
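

# --- Reference only: minimal sketches of the helpers assumed above. ---
# The real implementations are expected to live in svoice/utils.py; these
# are hedged reconstructions to make the calls in train() concrete, not
# the repo's authoritative code.
from contextlib import contextmanager


def pull_metric(history, name):
    """Collect one named metric (e.g. 'valid') across all recorded epochs."""
    return [metrics[name] for metrics in history if name in metrics]


def copy_state(state):
    """Clone each tensor so later optimizer steps cannot mutate the copy."""
    return {k: v.cpu().clone() for k, v in state.items()}


@contextmanager
def swap_state(model, state):
    """Temporarily load `state` into `model`, restoring old weights on exit."""
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(old_state)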