in demucs/solver.py [0:0]
def __init__(self, loaders, model, optimizer, args):
    """Initialize the solver: model wrappers, EMAs, augmentation and checkpoint paths.

    Args:
        loaders: data loaders, stored unchanged on `self.loaders`.
        model: the model to train. The device of its first parameter
            becomes `self.device`.
        optimizer: the optimizer, stored on `self.optimizer` and also passed
            to `states.get_quantizer`.
        args: experiment configuration. This method reads `args.quant`,
            `args.ema.batch` / `args.ema.epoch`, `args.dset.samplerate`,
            `args.dset.shift`, and `args.augment` (`shift_same`, `flip`,
            `scale`, `remix`).
    """
    self.args = args
    self.loaders = loaders
    self.model = model
    self.optimizer = optimizer

    # The quantizer is built from the raw model before the distributed wrapper.
    self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
    self.dmodel = distrib.wrap(model)
    first_param = next(iter(self.model.parameters()))
    self.device = first_param.device

    # Exponential moving averages of the model, updated either once per batch
    # or once per epoch. The final best model is picked, based on the valid
    # loss, among all the EMAs and the original model.
    self.emas = {'batch': [], 'epoch': []}
    for kind, ema_list in self.emas.items():
        decays = getattr(args.ema, kind)
        if not decays:
            continue
        # Per-batch EMAs stay on the model device; per-epoch ones go to CPU.
        ema_device = 'cpu' if kind != 'batch' else self.device
        ema_list.extend(
            ModelEMA(self.model, decay, device=ema_device) for decay in decays)

    # Data augmentation pipeline: shift is always present, the rest is optional.
    shift_samples = int(args.dset.samplerate * args.dset.shift)
    pipeline = [augment.Shift(shift=shift_samples, same=args.augment.shift_same)]
    if args.augment.flip:
        pipeline.extend([augment.FlipChannels(), augment.FlipSign()])
    for name in ('scale', 'remix'):
        options = getattr(args.augment, name)
        if options.proba:
            # 'scale' -> augment.Scale, 'remix' -> augment.Remix
            augment_cls = getattr(augment, name.capitalize())
            pipeline.append(augment_cls(**options))
    self.augment = torch.nn.Sequential(*pipeline)

    # Experiment folder and checkpoint locations.
    xp = get_xp()
    self.folder = xp.folder
    self.checkpoint_file = xp.folder / 'checkpoint.th'
    self.best_file = xp.folder / 'best.th'
    logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())

    # Best-model tracking starts out empty.
    self.best_state = None
    self.best_changed = False

    # History is taken from the experiment link.
    self.link = xp.link
    self.history = self.link.history

    self._reset()