def load_checkpoint()

in pycls/core/checkpoint.py [0:0]


def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None):
    """
    Loads a checkpoint selectively based on the input options.

    Each checkpoint contains both the model and model_ema weights (except checkpoints
    created by old versions of the code). If both the model and model_ema weights are
    requested, both sets of weights are loaded. If only the model weights are requested
    (that is if model_ema=None), the *better* set of weights is selected to be loaded
    (according to the lesser of test_err and ema_err, also stored in the checkpoint;
    on a tie the regular model weights are used).

    The code is backward compatible with checkpoints that do not store the ema weights
    or the errors: a missing "ema_state" falls back to "model_state", and a missing
    "test_err" / "ema_err" is reported as 100.

    Args:
        checkpoint_file: path of the checkpoint, accessed through pathmgr.
        model: model whose (unwrapped) weights are overwritten from the checkpoint.
        model_ema: optional EMA model. When given, model receives "model_state" and
            model_ema receives the ema weights.
        optimizer: optional optimizer whose state is restored from "optimizer_state".

    Returns:
        Tuple (epoch, test_err, ema_err) read from the checkpoint.
    """
    err_str = "Checkpoint '{}' not found"
    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with pathmgr.open(checkpoint_file, "rb") as f:
        # Always deserialize onto the CPU, independent of where it was saved
        checkpoint = torch.load(f, map_location="cpu")
    # Get test_err and ema_err (old checkpoints may lack them; default to 100)
    test_err = checkpoint.get("test_err", 100)
    ema_err = checkpoint.get("ema_err", 100)
    # Old checkpoints lack "ema_state"; fall back to the regular model weights
    ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
    # Compare against None explicitly: container modules (e.g. nn.Sequential)
    # define __len__, so plain truthiness is not a reliable "was it passed" test
    if model_ema is not None:
        unwrap_model(model).load_state_dict(checkpoint["model_state"])
        unwrap_model(model_ema).load_state_dict(checkpoint[ema_state])
    else:
        # Pick the weights with the lower error; ties keep the regular weights
        best_state = "model_state" if test_err <= ema_err else ema_state
        unwrap_model(model).load_state_dict(checkpoint[best_state])
    # Load optimizer if requested
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"], test_err, ema_err