in denoiser/solver.py
def _run_one_epoch(self, epoch, cross_valid=False):
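    """Run one pass over the train set (or the cross-validation set when
    `cross_valid` is True) and return the epoch loss averaged across workers."""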
    total_loss = 0
    data_loader = self.tr_loader if not cross_valid else self.cv_loader
    # get a different order for distributed training, otherwise this will get ignored
    data_loader.epoch = epoch

    label = ["Train", "Valid"][cross_valid]
    name = label + f" | Epoch {epoch + 1}"
    logprog = LogProgress(logger, data_loader, updates=self.num_prints, name=name)
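    # iterate over batches; LogProgress logs the running loss `self.num_prints` times per epoch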
    for i, data in enumerate(logprog):
        noisy, clean = [x.to(self.device) for x in data]
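        # training only: decompose the batch into (noise, clean), augment both,
        # then remix so `noisy` stays consistent with the augmented `clean`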
        if not cross_valid:
            sources = torch.stack([noisy - clean, clean])
            sources = self.augment(sources)
            noise, clean = sources
            noisy = noise + clean
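        # forward pass; `self.dmodel` is the model, possibly wrapped for distributed training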
        estimate = self.dmodel(noisy)
        # compute the loss; anomaly detection makes the backward pass fail fast
        # on NaN/Inf gradients, at some speed cost
        with torch.autograd.set_detect_anomaly(True):
            if self.args.loss == 'l1':
                loss = F.l1_loss(clean, estimate)
            elif self.args.loss == 'l2':
                loss = F.mse_loss(clean, estimate)
            elif self.args.loss == 'huber':
                loss = F.smooth_l1_loss(clean, estimate)
            else:
                raise ValueError(f"Invalid loss {self.args.loss}")

            # MultiResolution STFT loss
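            # `sc_loss` is the spectral convergence term and `mag_loss` the
            # log-STFT magnitude term, each summed over several STFT resolutions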
            if self.args.stft_loss:
                sc_loss, mag_loss = self.mrstftloss(estimate.squeeze(1), clean.squeeze(1))
                loss += sc_loss + mag_loss

            # optimize model in training mode
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
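        # track the running average loss shown by LogProgress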
        total_loss += loss.item()
        logprog.update(loss=format(total_loss / (i + 1), ".5f"))
        # Just in case, clear some memory
        del loss, estimate
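    # average this worker's mean loss across all workers, weighted by its batch count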
    return distrib.average([total_loss / (i + 1)], i + 1)[0]