sing/fondation/trainer.py
def train(self):
    """
    Train :attr:`model` for :attr:`epochs` epochs, resuming from the last
    checkpoint in :attr:`checkpoint_path` if one exists.
    """
    # Resume from the latest checkpoint when available.
    last_epoch, state = utils.load_checkpoint(self.checkpoint_path)
    if state is not None:
        self.model.load_state_dict(state, strict=False)
    start_epoch = last_epoch + 1
    if start_epoch > self.epochs:
        raise ValueError(("Checkpoint has been trained for {} "
                          "epochs but we aim for {} epochs").format(
                              start_epoch, self.epochs))
    if start_epoch > 0:
        print("Resuming training at epoch {}".format(start_epoch))
    for epoch in range(start_epoch, self.epochs):
        # Train for one epoch, then persist the weights so training
        # can be resumed from this point if interrupted.
        self._train_epoch(self.train_dataset, epoch)
        utils.save_checkpoint(self.checkpoint_path, epoch,
                              self.model.state_dict())
        # Evaluate on every evaluation dataset without tracking gradients.
        with torch.no_grad():
            for name, dataset in self.eval_datasets.items():
                self._eval_dataset(name, dataset, epoch)
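
The loop above relies on two helpers from `utils` for persistence. Below is a minimal sketch of how they could be implemented, assuming checkpoints are stored as a single `torch.save` payload holding the last completed epoch and the model `state_dict`; this layout and the `(-1, None)` convention for a missing checkpoint are assumptions for illustration, not the repository's actual implementation.

import os

import torch


def save_checkpoint(checkpoint_path, epoch, state):
    # Persist the last completed epoch together with the model weights.
    # (Hypothetical payload layout, assumed for illustration.)
    torch.save({"epoch": epoch, "state": state}, checkpoint_path)


def load_checkpoint(checkpoint_path):
    # Return (-1, None) when no checkpoint exists, so that train() starts
    # from epoch 0 with freshly initialized weights.
    if not os.path.exists(checkpoint_path):
        return -1, None
    payload = torch.load(checkpoint_path, map_location="cpu")
    return payload["epoch"], payload["state"]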