def train()

in svoice/solver.py [0:0]


    def train(self):
        """Run the training loop, resuming after any epochs already in ``self.history``.

        Previously recorded metrics are replayed to the log first. Then, for every
        epoch from ``len(self.history)`` up to ``self.epochs``:

        * train one epoch and cross-validate (under ``torch.no_grad()``);
        * step the LR scheduler if one is configured (``'plateau'`` uses the
          validation loss);
        * snapshot the model weights into ``self.best_state`` when the validation
          loss is the best so far, or every epoch when ``args.keep_last`` is set;
        * every ``self.eval_every`` epochs and on the last epoch, evaluate the
          ``self.best_state`` weights on the test loader and write separated samples;
        * append the epoch's metrics to ``self.history``; on rank 0, write the
          history to ``self.history_file`` and serialize a checkpoint if
          ``self.checkpoint`` is set.
        """
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
            # Report 1-based epochs, matching the summaries logged below.
            logger.info(f"Epoch {epoch + 1}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off Batchnorm & Dropout
            with torch.no_grad():
                valid_loss = self._run_one_epoch(epoch, cross_valid=True)
            # NOTE: `start` is the beginning of the epoch, so this time includes
            # training as well as validation.
            logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))

            # learning rate scheduling
            if self.sched:
                if self.args.lr_sched == 'plateau':
                    # Plateau-style schedulers need the monitored metric.
                    self.sched.step(valid_loss)
                else:
                    self.sched.step()
                # Read the LR directly from the param groups; no need to build
                # a full optimizer state_dict just to inspect it.
                logger.info(
                    f'Learning rate adjusted: {self.optimizer.param_groups[0]["lr"]:.5f}')

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss,
                       'valid': valid_loss, 'best': best_loss}
            # Save the best model (or, with keep_last, always the latest weights).
            is_best = valid_loss == best_loss
            if is_best:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
            if is_best or self.args.keep_last:
                self.best_state = copy_state(self.model.state_dict())

            # evaluate and separate samples every 'eval_every' argument number of epochs
            # also evaluate on last epoch
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    sisnr, pesq, stoi = evaluate(
                        self.args, self.model, self.tt_loader, self.args.sample_rate)
                metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})

                # separate some samples (uses the current weights, restored
                # on exit from swap_state above)
                logger.info('Separate and save samples...')
                separate(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(
                f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            # Only the main process writes history/checkpoints.
            if distrib.rank == 0:
                # Close the file deterministically so the history is fully
                # flushed even if a later step crashes.
                with open(self.history_file, "w") as history_fp:
                    json.dump(self.history, history_fp, indent=2)
                # Save model each epoch
                if self.checkpoint:
                    self._serialize(self.checkpoint)
                    logger.debug("Checkpoint saved to %s",
                                 self.checkpoint.resolve())