in svoice/solver.py
def _run_one_epoch(self, epoch, cross_valid=False):
    total_loss = 0
    data_loader = self.tr_loader if not cross_valid else self.cv_loader
    # give the distributed sampler a different shuffle order each epoch;
    # non-distributed loaders ignore this attribute
    data_loader.epoch = epoch
    label = ["Train", "Valid"][cross_valid]
    name = label + f" | Epoch {epoch + 1}"
    logprog = LogProgress(logger, data_loader,
                          updates=self.num_prints, name=name)
    for i, data in enumerate(logprog):
        mixture, lengths, sources = [x.to(self.device) for x in data]
        estimate_source = self.dmodel(mixture)
        # during validation, evaluate only the last layer's estimate
        if cross_valid:
            estimate_source = estimate_source[-1:]
        loss = 0
        cnt = len(estimate_source)
        # apply a loss after each layer's output, weighting later layers
        # more heavily; anomaly detection helps trace non-finite gradients
        # in the backward pass, at a noticeable speed cost
        with torch.autograd.set_detect_anomaly(True):
            for c_idx, est_src in enumerate(estimate_source):
                coeff = (c_idx + 1) * (1 / cnt)
                # SI-SNR loss (permutation invariant)
                sisnr_loss, snr, est_src, reorder_est_src = cal_loss(
                    sources, est_src, lengths)
                loss += coeff * sisnr_loss
            loss /= len(estimate_source)
            if not cross_valid:
                # optimize model in training mode
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()
        total_loss += loss.item()
        logprog.update(loss=format(total_loss / (i + 1), ".5f"))
        # Just in case, clear some memory
        del loss, estimate_source
    return distrib.average([total_loss / (i + 1)], i + 1)[0]
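
The weighting coeff = (c_idx + 1) * (1 / cnt) ramps linearly over the layers, so the deepest estimate always contributes with full weight. A quick illustrative check of the schedule (standalone, not from the repo):

cnt = 4  # e.g. four intermediate outputs
coeffs = [(c_idx + 1) * (1 / cnt) for c_idx in range(cnt)]
print(coeffs)  # [0.25, 0.5, 0.75, 1.0] -- the last layer gets weight 1.0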
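
cal_loss is imported from the surrounding module and returns the SI-SNR objective after permutation-invariant reordering of the estimated speakers. As a point of reference, here is a minimal, self-contained sketch of plain SI-SNR for a single pair of signals; it is an assumption-level illustration that skips the padding mask (lengths) and the permutation search the real loss performs:

import torch

def si_snr(est, src, eps=1e-8):
    # Sketch only: scale-invariant SNR in dB for two 1-D signals;
    # not the repo's cal_loss (no length masking, no PIT).
    est = est - est.mean()
    src = src - src.mean()
    # project the estimate onto the target to factor out scale
    s_target = (torch.dot(est, src) / (torch.dot(src, src) + eps)) * src
    e_noise = est - s_target
    return 10 * torch.log10(s_target.pow(2).sum() / (e_noise.pow(2).sum() + eps))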
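
distrib.average folds the per-process mean losses into a single number when training is distributed. Below is a sketch of count-weighted averaging with torch.distributed.all_reduce, under the assumption that each rank passes the number of batches it processed as count; the repo's helper may differ in detail:

import torch
import torch.distributed as dist

def average(metrics, count=1.0):
    # Sketch: weighted average across ranks; single-process runs pass through.
    if not dist.is_available() or not dist.is_initialized():
        return metrics
    # append a 1 so tensor[-1] accumulates the total count after the reduce
    tensor = torch.tensor(list(metrics) + [1.0], dtype=torch.float64) * count
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).tolist()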
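
For orientation, a hypothetical driver showing how the two modes of _run_one_epoch would typically be called; num_epochs and the train/eval bookkeeping are assumptions, not copied from the repo:

def train(self, num_epochs):
    # Hypothetical Solver.train sketch: one training pass, then one
    # gradient-free validation pass per epoch.
    for epoch in range(num_epochs):
        self.model.train()
        train_loss = self._run_one_epoch(epoch)
        self.model.eval()
        with torch.no_grad():  # validation needs no gradients
            valid_loss = self._run_one_epoch(epoch, cross_valid=True)
        logger.info(f"Epoch {epoch + 1}: train {train_loss:.5f}, "
                    f"valid {valid_loss:.5f}")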