def load_checkpoint()

in utils/checkpoint_utils.py


import math
import os
from typing import Optional, Union

import torch

# `BaseOptim`, `CHECKPOINT_EXTN`, `get_local_path`, `is_master`,
# `load_state_dict`, and `logger` are project-local helpers imported
# elsewhere in this module.


def load_checkpoint(opts,
                    model: torch.nn.Module,
                    optimizer: Union[BaseOptim, torch.optim.Optimizer],
                    gradient_scalar: torch.cuda.amp.GradScaler,
                    model_ema: Optional[torch.nn.Module] = None):
    """Restore training state (model, optimizer, AMP grad scaler, and optional
    EMA weights) from a checkpoint file, if one can be located."""
    resume_loc = getattr(opts, "common.resume", None)
    dev_id = getattr(opts, "dev.device_id", None)
    device = getattr(opts, "dev.device", torch.device('cpu'))
    start_epoch = start_iteration = 0
    # When the checkpoint metric is maximized (e.g., top-1 accuracy), any real
    # score beats 0.0; when minimized (e.g., loss), any real score beats +inf.
    best_metric = 0.0 if getattr(opts, "stats.checkpoint_metric_max", False) else math.inf
    auto_resume = getattr(opts, "common.auto_resume", False)
    exp_dir = getattr(opts, "common.exp_loc", None)
    is_master_node = is_master(opts)
    # With auto-resume enabled and no explicit resume path, fall back to the
    # checkpoint in the experiment directory.
    if resume_loc is None and auto_resume and exp_dir is not None:
        resume_loc = '{}/checkpoint.{}'.format(exp_dir, CHECKPOINT_EXTN)

    resume_loc = get_local_path(opts, path=resume_loc)
    if resume_loc is not None and os.path.isfile(resume_loc):
        # Map tensors onto the target device: a specific GPU when a device id
        # is set, otherwise the configured default device.
        if dev_id is None:
            checkpoint = torch.load(resume_loc, map_location=device)
        else:
            checkpoint = torch.load(resume_loc, map_location='cuda:{}'.format(dev_id))

        # Resume one step past the saved epoch/iteration counters.
        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iterations'] + 1
        best_metric = checkpoint['best_metric']

        model = load_state_dict(model, checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        gradient_scalar.load_state_dict(checkpoint['gradient_scalar_state_dict'])

        # Restore EMA weights only if the checkpoint carries them.
        if model_ema is not None and 'ema_state_dict' in checkpoint:
            model_ema.ema_model = load_state_dict(
                model_ema.ema_model, checkpoint['ema_state_dict'])

        if is_master_node:
            logger.log('Loaded checkpoint from {}'.format(resume_loc))
            logger.log('Resuming training for epoch {}'.format(start_epoch))
    else:
        if is_master_node:
            logger.log("No checkpoint found at '{}'".format(resume_loc))
    return model, optimizer, gradient_scalar, start_epoch, start_iteration, best_metric, model_ema
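
Below is a minimal usage sketch, not taken from the repository. It assumes the
module's project-local helpers (get_local_path, load_state_dict, is_master,
logger) are available, that an argparse-style Namespace can stand in for opts
(its dotted attribute names match the getattr() keys above), and that the toy
model, optimizer, metric values, and file name are purely illustrative.

import argparse

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled: CPU-only sketch

# The keys mirror those read by load_checkpoint() above.
torch.save({
    'epoch': 4,
    'iterations': 1999,
    'best_metric': 0.73,
    'model_state_dict': model.state_dict(),
    'optim_state_dict': optimizer.state_dict(),
    'gradient_scalar_state_dict': scaler.state_dict(),
}, 'checkpoint.pt')

opts = argparse.Namespace()
setattr(opts, 'common.resume', 'checkpoint.pt')
setattr(opts, 'dev.device', torch.device('cpu'))
setattr(opts, 'stats.checkpoint_metric_max', True)  # metric is maximized

(model, optimizer, scaler, start_epoch, start_iteration,
 best_metric, model_ema) = load_checkpoint(opts, model, optimizer, scaler)
# start_epoch == 5, start_iteration == 2000, best_metric == 0.73

Since load_checkpoint rebinds model (and possibly model_ema) internally,
callers should reassign all returned objects as above rather than relying on
in-place mutation alone.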