in utils/saver.py [0:0]
def load_checkpoints(args, model, optimizer=None, lr_scheduler=None, logger=None):
    """Restore training state from the checkpoint file at ``args.resume``.

    The checkpoint is loaded onto the CPU. It must contain the keys
    ``'epoch'`` and ``'state_dict'``. It must also contain ``'optimizer'`` and
    ``'lr_scheduler'`` when the matching objects are passed in and restored.

    Args:
        args: Config object. This function reads:

            * ``resume``: the checkpoint path.
            * ``lr_scheduler``: its attribute names are treated as scheduler
              keys whose values must NOT be restored from the checkpoint.
            * The optional boolean flags
              ``resume_with_a_different_optimizer`` and
              ``resume_with_a_different_lr_scheduler``.

            It also writes ``args.start_epoch``, setting it to the checkpoint's
            ``'epoch'``.
        model: Object whose ``load_state_dict`` receives
            ``checkpoint['state_dict']`` (presumably a ``torch.nn.Module``).
        optimizer: Optional optimizer. Its state is restored unless
            ``args.resume_with_a_different_optimizer`` is truthy.
        lr_scheduler: Optional scheduler. If neither "different" flag is set,
            its state is restored, except for the overridden keys and
            ``'clamp_lr'``. Otherwise it is only stepped to the checkpoint's
            ``last_epoch``.
        logger: Optional object with an ``info`` method, used for progress
            messages.

    Raises:
        FileNotFoundError: If ``args.resume`` is not an existing file.
    """
    resume_path = args.resume
    # Explicit check rather than ``assert``, so validation survives ``python -O``.
    if not os.path.isfile(resume_path):
        raise FileNotFoundError("=> no checkpoint found at '{}'".format(resume_path))

    if logger:
        logger.info("=> loading checkpoint '{}'".format(resume_path))
    # NOTE(security): torch.load unpickles the file. Only resume from trusted checkpoints.
    with open(resume_path, 'rb') as f:
        checkpoint = torch.load(f, map_location=torch.device('cpu'))

    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])

    different_optimizer = getattr(args, 'resume_with_a_different_optimizer', False)
    different_lr_scheduler = getattr(args, 'resume_with_a_different_lr_scheduler', False)

    if optimizer is not None and not different_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if lr_scheduler is not None:
        scheduler_state = checkpoint['lr_scheduler']
        if different_optimizer or different_lr_scheduler:
            # Do not restore the saved scheduler state; only move the
            # scheduler to the checkpoint's last epoch.
            lr_scheduler.step(scheduler_state['last_epoch'])
        else:
            # Settings defined in args.lr_scheduler (plus 'clamp_lr') take
            # precedence over the checkpoint, so drop them before loading.
            for key in list(vars(args.lr_scheduler)) + ['clamp_lr']:
                scheduler_state.pop(key, None)
            lr_scheduler.load_state_dict(scheduler_state)

    if logger:
        logger.info("=> loaded checkpoint '{}' (epoch {})"
                    .format(resume_path, checkpoint['epoch']))
    del checkpoint