def _run_one_epoch()

in denoiser/solver.py [0:0]


    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one pass over the training or validation loader and return its mean loss.

        In training mode (``cross_valid=False``) each batch is augmented, passed
        through ``self.dmodel`` and followed by an optimizer step. In validation
        mode the validation loader is used, no augmentation is applied and the
        optimizer is not touched.

        Args:
            epoch: zero-based epoch index. It is stored on the loader so that
                distributed training can use a different order per epoch, and
                it is shown 1-based in the progress log.
            cross_valid: if True, evaluate on ``self.cv_loader`` instead of
                training on ``self.tr_loader``.

        Returns:
            The mean per-batch loss for this epoch, passed through
            ``distrib.average`` with the number of batches as its count
            (presumably a cross-process weighted mean; confirm in ``distrib``).
            If the loader yields no batches, 0.0 is passed with a count of 0.

        Raises:
            ValueError: if ``self.args.loss`` is not 'l1', 'l2' or 'huber'.
        """
        total_loss = 0
        num_batches = 0
        data_loader = self.cv_loader if cross_valid else self.tr_loader

        # get a different order for distributed training, otherwise this will get ignored
        data_loader.epoch = epoch

        label = "Valid" if cross_valid else "Train"
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader, updates=self.num_prints, name=name)
        for data in logprog:
            noisy, clean = [x.to(self.device) for x in data]
            if not cross_valid:
                # Augment the noise and clean components jointly, then re-mix
                # them so the model input matches the augmented target.
                sources = torch.stack([noisy - clean, clean])
                sources = self.augment(sources)
                noise, clean = sources
                noisy = noise + clean

            # Gradients are only needed when training. Disabling them for
            # validation avoids building an autograd graph; this is harmless if
            # the caller already wraps validation in torch.no_grad().
            # (Anomaly detection, previously always on here, is a debugging
            # aid that slows every step; wrap with
            # torch.autograd.detect_anomaly() locally if needed.)
            with torch.set_grad_enabled(not cross_valid):
                estimate = self.dmodel(noisy)
                # Waveform-domain loss on the model's final output.
                if self.args.loss == 'l1':
                    loss = F.l1_loss(clean, estimate)
                elif self.args.loss == 'l2':
                    loss = F.mse_loss(clean, estimate)
                elif self.args.loss == 'huber':
                    loss = F.smooth_l1_loss(clean, estimate)
                else:
                    raise ValueError(f"Invalid loss {self.args.loss}")
                # MultiResolution STFT loss
                if self.args.stft_loss:
                    # squeeze(1) presumably drops a singleton channel axis,
                    # i.e. (batch, 1, time) -> (batch, time); TODO confirm.
                    sc_loss, mag_loss = self.mrstftloss(estimate.squeeze(1), clean.squeeze(1))
                    loss += sc_loss + mag_loss

            # optimize model in training mode
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            num_batches += 1
            total_loss += loss.item()
            logprog.update(loss=format(total_loss / num_batches, ".5f"))
            # Just in case, clear some memory
            del loss, estimate

        # Guard against an empty loader. The old code referenced the loop index
        # here and raised UnboundLocalError when no batch was seen.
        mean_loss = total_loss / num_batches if num_batches else 0.0
        return distrib.average([mean_loss], num_batches)[0]