def load_checkpoint()

in src/sagemaker_xgboost_container/checkpointing.py [0:0]


def load_checkpoint(checkpoint_dir, max_try=5):
    """Locate the most recent checkpoint file in ``checkpoint_dir``.

    Candidate files are those whose name is ``CHECKPOINT_FILENAME`` followed by
    a dot and a decimal number; the number is read back as the index of the
    last completed iteration stored in that checkpoint.

    :param checkpoint_dir: e.g., /opt/ml/checkpoints
    :param max_try: maximum number of checkpoints (newest first) to try before
        giving up. Values <= 0 mean no checkpoint is tried.
    :return xgb_model: file path of stored xgb model. None if no usable checkpoint.
    :return iteration: iterations completed before last checkpoint (0 if none).
    """
    if not checkpoint_dir or not os.path.exists(checkpoint_dir):
        return None, 0

    # Escape the prefix so any regex metacharacters in it are matched literally.
    regex = r"^{0}\.[0-9]+$".format(re.escape(CHECKPOINT_FILENAME))
    checkpoints = [f for f in os.listdir(checkpoint_dir) if re.match(regex, f)]
    if not checkpoints:
        return None, 0
    # Relies on _sort_checkpoints ordering the list in place so that the newest
    # checkpoint ends up last (pop() below takes it first).
    _sort_checkpoints(checkpoints)

    # Never attempt more candidates than exist: popping an exhausted list would
    # raise IndexError instead of falling through to "no checkpoint".
    for _ in range(min(max_try, len(checkpoints))):
        latest_checkpoint = checkpoints.pop()
        try:
            # NOTE(review): nothing in this try body currently raises
            # XGBoostError (no model is loaded here), so the first candidate is
            # always accepted -- confirm whether model validation was intended.
            xgb_model = os.path.join(checkpoint_dir, latest_checkpoint)
            # rsplit so a '.' inside the prefix cannot break the suffix parse.
            iteration = int(latest_checkpoint.rsplit(".", 1)[1]) + 1
            return xgb_model, iteration
        except XGBoostError:
            logging.debug("Wrong checkpoint model format %s", latest_checkpoint)

    # No candidate was usable: report "no checkpoint" consistently instead of
    # pairing a rejected file path with iteration 0.
    return None, 0