in denoiser/solver.py [0:0]
def train(self):
    """Run the full optimization loop for the configured number of epochs.

    Flow per invocation:
      * If ``args.save_again`` is set, only re-serialize the current model
        state and return (no training).
      * Replay metrics of any previously completed epochs (resumed run) so
        the log output stays complete, then continue training from where
        ``self.history`` left off up to ``self.epochs``.

    Per epoch: train one pass, optionally cross-validate, track the best
    validation loss (snapshotting the best model state), every
    ``self.eval_every`` epochs (and on the last epoch) evaluate PESQ/STOI on
    the test set and enhance samples, then append the epoch metrics to
    ``self.history``. On the distributed rank-0 process the history is
    dumped to JSON and a checkpoint is optionally saved each epoch.

    Side effects: mutates ``self.history`` and ``self.best_state``; writes
    ``self.history_file`` and (if enabled) the checkpoint file; logs
    progress throughout. Returns ``None``.
    """
    if self.args.save_again:
        self._serialize()
        return
    # Optimizing the model
    if self.history:
        # Resumed run: re-log metrics of already-finished epochs.
        logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            info = " ".join(f"{k.capitalize()}={v:.5f}" for k, v in metrics.items())
            logger.info(f"Epoch {epoch + 1}: {info}")

    # Start from the first epoch not yet recorded in history.
    for epoch in range(len(self.history), self.epochs):
        # Train one epoch
        self.model.train()
        start = time.time()
        logger.info('-' * 70)
        logger.info("Training...")
        train_loss = self._run_one_epoch(epoch)
        logger.info(
            bold(f'Train Summary | End of Epoch {epoch + 1} | '
                 f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

        if self.cv_loader:
            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()
            with torch.no_grad():
                valid_loss = self._run_one_epoch(epoch, cross_valid=True)
            logger.info(
                bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                     f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))
        else:
            # No validation set: use 0 so this epoch can never beat a real
            # validation loss in the `min` below (matches prior behavior).
            valid_loss = 0

        # Best = minimum over all recorded valid losses plus this epoch's.
        best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
        metrics = {'train': train_loss, 'valid': valid_loss, 'best': best_loss}
        # Save the best model
        if valid_loss == best_loss:
            logger.info(bold('New best valid loss %.4f'), valid_loss)
            self.best_state = copy_state(self.model.state_dict())

        # evaluate and enhance samples every 'eval_every' argument number of epochs
        # also evaluate on last epoch
        if ((epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1) and self.tt_loader:
            # Evaluate on the testset
            logger.info('-' * 70)
            logger.info('Evaluating on the test set...')
            # We switch to the best known model for testing
            with swap_state(self.model, self.best_state):
                pesq, stoi = evaluate(self.args, self.model, self.tt_loader)
            metrics.update({'pesq': pesq, 'stoi': stoi})

            # enhance some samples
            logger.info('Enhance and save samples...')
            enhance(self.args, self.model, self.samples_dir)

        self.history.append(metrics)
        info = " | ".join(f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
        logger.info('-' * 70)
        logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

        if distrib.rank == 0:
            # Fix: use a context manager so the history file is flushed and
            # closed deterministically (the old open() leaked the handle
            # until GC, risking a truncated file on crash).
            with open(self.history_file, "w") as history_fh:
                json.dump(self.history, history_fh, indent=2)
            # Save model each epoch
            if self.checkpoint:
                self._serialize()
                logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())