in utils/checkpoint_utils.py [0:0]
def load_checkpoint(opts,
                    model: torch.nn.Module,
                    optimizer: Union[BaseOptim, torch.optim.Optimizer],
                    gradient_scalar: torch.cuda.amp.GradScaler,
                    model_ema: Optional[torch.nn.Module] = None):
    """Optionally restore training state (model, optimizer, AMP grad scaler, and
    EMA weights) from `common.resume`, or from the experiment directory's latest
    checkpoint when auto-resume is enabled."""
    resume_loc = getattr(opts, "common.resume", None)
    dev_id = getattr(opts, "dev.device_id", None)
    device = getattr(opts, "dev.device", torch.device('cpu'))
    start_epoch = start_iteration = 0
    best_metric = 0.0 if getattr(opts, "stats.checkpoint_metric_max", False) else math.inf
    auto_resume = getattr(opts, "common.auto_resume", False)
    exp_dir = getattr(opts, "common.exp_loc", None)
    is_master_node = is_master(opts)

    # No explicit resume path given: fall back to the experiment directory's
    # checkpoint when auto-resume is enabled.
    if resume_loc is None and auto_resume and exp_dir is not None:
        resume_loc = '{}/checkpoint.{}'.format(exp_dir, CHECKPOINT_EXTN)

    resume_loc = get_local_path(opts, path=resume_loc)
    if resume_loc is not None and os.path.isfile(resume_loc):
        # Map the checkpoint onto the target device (or a specific GPU id).
        if dev_id is None:
            checkpoint = torch.load(resume_loc, map_location=device)
        else:
            checkpoint = torch.load(resume_loc, map_location='cuda:{}'.format(dev_id))

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iterations'] + 1
        best_metric = checkpoint['best_metric']

        # Restore model, optimizer, and grad-scaler state.
        model = load_state_dict(model, checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        gradient_scalar.load_state_dict(checkpoint['gradient_scalar_state_dict'])

        # Restore EMA weights if an EMA wrapper is given and the checkpoint has them.
        if model_ema is not None and 'ema_state_dict' in checkpoint:
            model_ema.ema_model = load_state_dict(model_ema.ema_model, checkpoint['ema_state_dict'])

        if is_master_node:
            logger.log('Loaded checkpoint from {}'.format(resume_loc))
            logger.log('Resuming training for epoch {}'.format(start_epoch))
    else:
        if is_master_node:
            logger.log("No checkpoint found at '{}'".format(resume_loc))

    return model, optimizer, gradient_scalar, start_epoch, start_iteration, best_metric, model_ema
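
A minimal usage sketch of the call site follows. The surrounding setup is illustrative only: it assumes `opts` comes from the project's argument parser and that `model`, `optimizer`, and `max_epochs` are defined earlier in the training entry point; these names are placeholders, not taken from this file.

# Hypothetical call site (sketch); `opts`, `model`, `optimizer`, and
# `max_epochs` are assumed to exist in the training entry point.
gradient_scalar = torch.cuda.amp.GradScaler()
model, optimizer, gradient_scalar, start_epoch, start_iteration, best_metric, model_ema = load_checkpoint(
    opts=opts,
    model=model,
    optimizer=optimizer,
    gradient_scalar=gradient_scalar,
    model_ema=None,  # pass an EMA wrapper here to also restore its weights
)
for epoch in range(start_epoch, max_epochs):
    ...  # training loop resumes from the restored epoch/iteration counters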