in src/sagemaker_xgboost_container/checkpointing.py [0:0]
def load_checkpoint(checkpoint_dir, max_try=5):
    """Locate the newest usable checkpoint file in ``checkpoint_dir``.

    Checkpoint files are expected to be named ``<CHECKPOINT_FILENAME>.<N>``,
    where ``N`` is a decimal integer. Up to ``max_try`` of them are tried,
    newest first.

    :param checkpoint_dir: e.g., /opt/ml/checkpoints
    :param max_try: maximum number of checkpoint files to try before giving up.
    :return: tuple ``(xgb_model, iteration)``.
        ``xgb_model`` is the file path of the selected checkpoint, or None if
        there is no usable checkpoint.
        ``iteration`` is ``N + 1`` for the selected checkpoint (iterations
        completed before that checkpoint), or 0 if there is none.
    """
    if not checkpoint_dir or not os.path.exists(checkpoint_dir):
        return None, 0

    # Escape the base name so any regex metacharacters in it (e.g. a '.')
    # are matched literally rather than as wildcards.
    pattern = re.compile(
        r"^{0}\.[0-9]+$".format(re.escape(CHECKPOINT_FILENAME))
    )
    checkpoints = [f for f in os.listdir(checkpoint_dir) if pattern.match(f)]
    if not checkpoints:
        return None, 0
    # Assumed to sort in place with the newest checkpoint last (the pop()
    # below relies on that ordering) -- TODO confirm against _sort_checkpoints.
    _sort_checkpoints(checkpoints)

    for _ in range(max_try):
        if not checkpoints:
            # Every candidate was rejected; stop rather than pop() an empty list.
            break
        latest_checkpoint = checkpoints.pop()
        try:
            # NOTE(review): nothing in this try body raises XGBoostError as
            # written, so the handler below is currently unreachable and the
            # newest checkpoint is always accepted. If the file should be
            # validated (e.g. by loading it into a Booster), do it here.
            xgb_model = os.path.join(checkpoint_dir, latest_checkpoint)
            # Split on the last '.' only, so a dotted base name cannot break
            # the parse; the regex above guarantees a numeric suffix.
            iteration = int(latest_checkpoint.rsplit(".", 1)[1]) + 1
        except XGBoostError:
            logging.debug("Wrong checkpoint model format %s", latest_checkpoint)
            continue
        return xgb_model, iteration

    # No candidate was accepted: report "no checkpoint" rather than returning
    # the path of a file that was rejected.
    return None, 0