in denoiser/solver.py [0:0]
def _reset(self):
    """Restore training state before the first epoch.

    At most one source of saved state is used:

    * ``self.checkpoint_file``, when checkpointing is enabled, the file
      exists and ``self.restart`` is false. This resumes the current run,
      so the model weights, the optimizer state (if present), the history
      and the best state are all restored.
    * Otherwise ``self.continue_from``, if set. The model is initialised
      from that package. If ``self.args.continue_best`` is set, its best
      state is used and the optimizer state is not restored. The package's
      history is not kept.

    Afterwards, if ``self.args.continue_pretrained`` is set, ``self.model`` is
    initialised from the model of that name in the ``pretrained`` module.
    This step is skipped when resuming from the run's own checkpoint,
    because doing it there would discard the weights just restored.
    """
    load_from = None
    load_best = False
    keep_history = True
    # True only when continuing this very run from its own checkpoint file.
    resuming = False
    if self.checkpoint and self.checkpoint_file.exists() and not self.restart:
        load_from = self.checkpoint_file
        resuming = True
    elif self.continue_from:
        load_from = self.continue_from
        load_best = self.args.continue_best
        keep_history = False

    if load_from:
        logger.info('Loading checkpoint model: %s', load_from)
        # map_location='cpu' so the package loads regardless of the device
        # it was saved from.
        package = torch.load(load_from, 'cpu')
        if load_best:
            self.model.load_state_dict(package['best_state'])
        else:
            self.model.load_state_dict(package['model']['state'])
        # Optimizer state is only restored alongside the last (non-best)
        # weights.
        if 'optimizer' in package and not load_best:
            self.optimizer.load_state_dict(package['optimizer'])
        if keep_history:
            self.history = package['history']
        self.best_state = package['best_state']

    continue_pretrained = self.args.continue_pretrained
    if continue_pretrained:
        if resuming:
            # The checkpoint already contains weights fine-tuned from this
            # pre-trained model; reloading it would throw that progress away.
            logger.info("Resuming from checkpoint, not re-loading "
                        "pre-trained model %s", continue_pretrained)
        else:
            logger.info("Fine tuning from pre-trained model %s",
                        continue_pretrained)
            model = getattr(pretrained, continue_pretrained)()
            self.model.load_state_dict(model.state_dict())