in utils.py [0:0]
def _find_latest_checkpoint(output_dir):
    """Return the path (str) of the highest-numbered ``checkpoint-<N>.pth`` in *output_dir*, or None."""
    prefix, suffix = 'checkpoint-', '.pth'
    latest_epoch, latest_path = -1, None
    # Path.glob only interprets the pattern, so '[', '*', '?' in the directory name are safe.
    for path in Path(output_dir).glob(prefix + '*' + suffix):
        epoch_str = path.name[len(prefix):-len(suffix)]
        # Accept only plain integer epochs; skips e.g. checkpoint-best.pth, checkpoint-best-ema.pth
        # and malformed names such as checkpoint-1.5.pth.
        if epoch_str.isdecimal() and int(epoch_str) > latest_epoch:
            latest_epoch, latest_path = int(epoch_str), str(path)
    return latest_path


def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    """Restore model/optimizer/scaler/EMA state from a checkpoint, if one is configured.

    If ``args.auto_resume`` is set and ``args.resume`` is empty, the highest-numbered
    ``checkpoint-<N>.pth`` in ``args.output_dir`` is chosen and written to ``args.resume``.

    Args:
        args: Namespace read for ``output_dir``, ``auto_resume``, ``resume``, ``eval``
            (only when the checkpoint's epoch is a string) and, optionally, ``model_ema``.
            ``args.resume`` and ``args.start_epoch`` may be updated in place.
        model: Unused; kept for interface compatibility.
        model_without_ddp: Module that receives ``checkpoint['model']``.
        optimizer: Receives ``checkpoint['optimizer']`` when it and ``'epoch'`` are present.
        loss_scaler: Receives ``checkpoint['scaler']`` when present.
        model_ema: Optional wrapper whose ``.ema`` module receives the EMA weights
            (falls back to ``checkpoint['model']`` when no ``'model_ema'`` entry exists).

    Raises:
        ValueError: if the checkpoint's epoch is a string (e.g. a "best" checkpoint)
            and ``args.eval`` is false — such checkpoints cannot resume training.
    """
    output_dir = Path(args.output_dir)
    # `not` (instead of len(...) == 0) also tolerates resume=None.
    if args.auto_resume and not args.resume:
        latest = _find_latest_checkpoint(output_dir)
        if latest is not None:
            args.resume = latest
        print("Auto resume checkpoint: %s" % args.resume)

    if not args.resume:
        return

    if args.resume.startswith(('https://', 'http://')):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
    else:
        # NOTE(review): torch>=2.6 defaults to weights_only=True; checkpoints that pickle
        # arbitrary objects (e.g. an argparse Namespace) may need weights_only=False — only
        # do that for trusted files.
        checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % args.resume)

    if 'optimizer' not in checkpoint or 'epoch' not in checkpoint:
        return

    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    if not isinstance(epoch, str):
        args.start_epoch = epoch + 1
    elif not args.eval:
        # String epochs ('best', 'best-ema') carry no resumable epoch counter.
        raise ValueError('Does not support resuming with checkpoint-best')

    if getattr(args, 'model_ema', False) and model_ema is not None:
        ema_key = 'model_ema' if 'model_ema' in checkpoint else 'model'
        model_ema.ema.load_state_dict(checkpoint[ema_key])
    if 'scaler' in checkpoint:
        loss_scaler.load_state_dict(checkpoint['scaler'])
    print("With optim & sched!")