in demucs/solver.py [0:0]
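# Module-level context assumed for this excerpt: `_run_one_epoch` is a method
# of the `Solver` class, and the names below are imported at the top of
# solver.py. The import paths are assumptions based on the demucs repository
# layout, not part of the excerpt itself.
import logging

import torch
import torch.nn.functional as F

from dora.log import LogProgress
from . import distrib
from .apply import apply_model
from .evaluate import new_sdr
from .svd import svd_penalty
from .utils import EMA

logger = logging.getLogger(__name__)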
def _run_one_epoch(self, epoch, train=True):
    args = self.args
    data_loader = self.loaders['train'] if train else self.loaders['valid']
    # Tell the sampler which epoch this is, so that distributed training
    # shuffles the data differently every epoch; outside of distributed
    # training the attribute is simply ignored.
    data_loader.sampler.epoch = epoch

    label = ["Valid", "Train"][train]
    name = label + f" | Epoch {epoch + 1}"
    total = len(data_loader)
    if args.max_batches:
        total = min(total, args.max_batches)
    logprog = LogProgress(logger, data_loader, total=total,
                          updates=self.args.misc.num_prints, name=name)
    averager = EMA()
    for idx, sources in enumerate(logprog):
        sources = sources.to(self.device)
        if train:
            sources = self.augment(sources)
            mix = sources.sum(dim=1)
        else:
            # Validation batches carry the mixture as entry 0 along the
            # source axis, followed by the ground-truth stems.
            mix = sources[:, 0]
            sources = sources[:, 1:]
        if not train and self.args.valid_apply:
            # Run the full separation pipeline, which can split long tracks
            # into chunks (no overlap between chunks here).
            estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
        else:
            estimate = self.dmodel(mix)
        if train and hasattr(self.model, 'transform_target'):
            # Some models learn against a transformed target rather than
            # the raw stems.
            sources = self.model.transform_target(mix, sources)
        assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
        # Reconstruction loss, reduced over all but the batch and source
        # dimensions.
        dims = tuple(range(2, sources.dim()))
        if args.optim.loss == 'l1':
            loss = F.l1_loss(estimate, sources, reduction='none')
            loss = loss.mean(dims).mean(0)
            reco = loss
        elif args.optim.loss == 'mse':
            loss = F.mse_loss(estimate, sources, reduction='none')
            loss = loss.mean(dims)
            reco = loss**0.5
            reco = reco.mean(0)
            # Average over the batch so that, as in the l1 branch, `loss`
            # holds one value per source before the weighting below.
            loss = loss.mean(0)
        else:
            raise ValueError(f"Invalid loss {args.optim.loss}")
        # Weighted average of the per-source losses.
        weights = torch.tensor(args.weights).to(sources)
        loss = (loss * weights).sum() / weights.sum()
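        # For illustration (these weights are not from the excerpt): with
        # sources (drums, bass, other, vocals) and weights (1, 1, 1, 2), this
        # evaluates to (l_drums + l_bass + l_other + 2 * l_vocals) / 5.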
        # Optional DiffQ quantization penalty, driven by the quantizer's
        # differentiable model-size estimate.
        ms = 0
        if self.quantizer is not None:
            ms = self.quantizer.model_size()
        if args.quant.diffq:
            loss += args.quant.diffq * ms

        losses = {}
        losses['reco'] = (reco * weights).sum() / weights.sum()
        losses['ms'] = ms
        if not train:
            # Per-source nSDR on the validation batch, plus its weighted
            # average across sources.
            nsdrs = new_sdr(sources, estimate.detach()).mean(0)
            total_nsdr = 0
            for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
                losses[f'nsdr_{source}'] = nsdr
                total_nsdr += w * nsdr
            losses['nsdr'] = total_nsdr / weights.sum()
        if train and args.svd.penalty > 0:
            # Regularization penalty computed from singular values of the
            # model weights; all other `args.svd` entries are forwarded.
            kw = dict(args.svd)
            kw.pop('penalty')
            penalty = svd_penalty(self.model, **kw)
            losses['penalty'] = penalty
            loss += args.svd.penalty * penalty

        losses['loss'] = loss
        for k, source in enumerate(self.model.sources):
            losses[f'reco_{source}'] = reco[k]
        # Backpropagate and step the optimizer in training mode.
        if train:
            loss.backward()
            grad_norm = 0
            for p in self.model.parameters():
                if p.grad is not None:
                    grad_norm += p.grad.data.norm()**2
            losses['grad'] = grad_norm ** 0.5
            if args.optim.clip_grad:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    args.optim.clip_grad)

            if self.args.flag == 'uns':
                # Debugging aid: report parameters that received no gradient.
                for n, p in self.model.named_parameters():
                    if p.grad is None:
                        print('no grad', n)
            self.optimizer.step()
            self.optimizer.zero_grad()
            for ema in self.emas['batch']:
                ema.update()
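        # `averager` smooths the metric values themselves with an exponential
        # moving average for logging; it is unrelated to the model-weight EMAs
        # updated above.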
        losses = averager(losses)
        logs = self._format_train(losses)
        logprog.update(**logs)
        # Just in case, free some memory before the next batch.
        del loss, estimate, reco, ms
        # Stop once `max_batches` batches have been processed (`idx` is
        # 0-based, so `idx + 1` batches are done at this point).
        if args.max_batches and idx + 1 >= args.max_batches:
            break
        if self.args.debug and train:
            break
        if self.args.flag == 'debug':
            break
    if train:
        for ema in self.emas['epoch']:
            ema.update()
    return distrib.average(losses, idx + 1)
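# A minimal sketch of how this method is typically driven; the surrounding
# training loop is not part of this excerpt, and `solver`, `args.epochs` and
# the eval()/no_grad() handling are assumptions:
#
#     for epoch in range(args.epochs):
#         solver.model.train()
#         train_metrics = solver._run_one_epoch(epoch, train=True)
#         solver.model.eval()
#         with torch.no_grad():
#             valid_metrics = solver._run_one_epoch(epoch, train=False)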