in svoice/solver.py [0:0]
def __init__(self, data, model, optimizer, args):
    """Set up the solver from its data loaders, model, optimizer and config.

    Args:
        data: mapping read at the keys ``'tr_loader'``, ``'cv_loader'`` and
            ``'tt_loader'`` (presumably train / validation / test loaders --
            NOTE(review): confirm against the caller).
        model: the network; stored as ``self.model``, and the result of
            ``distrib.wrap(model)`` is stored as ``self.dmodel``.
        optimizer: the optimizer; any learning-rate scheduler is built on it.
        args: configuration object. ``args.lr_sched`` picks the scheduler:
            ``'step'`` builds ``StepLR`` from ``args.step``, ``'plateau'``
            builds ``ReduceLROnPlateau`` from ``args.plateau``, and any other
            value leaves ``self.sched`` as ``None``. The remaining training,
            checkpoint, history, sample-dump and logging settings are copied
            onto attributes. ``self._reset()`` runs once everything is set.
    """
    # Full config, kept around for the separation tests.
    self.args = args

    # One data loader per split.
    self.tr_loader = data['tr_loader']
    self.cv_loader = data['cv_loader']
    self.tt_loader = data['tt_loader']

    # Raw model, its distrib-wrapped counterpart, and the optimizer.
    self.model = model
    self.dmodel = distrib.wrap(model)
    self.optimizer = optimizer

    # Optional learning-rate scheduler attached to the optimizer.
    sched_kind = args.lr_sched
    if sched_kind == 'step':
        step_cfg = args.step
        scheduler = StepLR(
            self.optimizer,
            step_size=step_cfg.step_size,
            gamma=step_cfg.gamma,
        )
    elif sched_kind == 'plateau':
        plateau_cfg = args.plateau
        scheduler = ReduceLROnPlateau(
            self.optimizer,
            factor=plateau_cfg.factor,
            patience=plateau_cfg.patience,
        )
    else:
        scheduler = None
    self.sched = scheduler

    # Training-loop settings.
    self.device, self.epochs, self.max_norm = (
        args.device, args.epochs, args.max_norm)

    # Checkpoint / resume settings. A checkpoint path is only set when
    # args.checkpoint is truthy.
    self.continue_from = args.continue_from
    self.eval_every = args.eval_every
    self.checkpoint = None
    if args.checkpoint:
        self.checkpoint = Path(args.checkpoint_file)
        logger.debug("Checkpoint will be saved to %s",
                     self.checkpoint.resolve())
    self.history_file = args.history_file
    self.best_state = None  # starts unset
    self.restart = args.restart

    # Per-epoch loss records start empty.
    self.history = []

    # Output location for samples, and logging frequency.
    self.samples_dir = args.samples_dir
    self.num_prints = args.num_prints

    self._reset()