in src/sagemaker_xgboost_container/checkpointing.py [0:0]
def train(train_args, checkpoint_dir):
    """Convenience function for script mode.

    Instead of running xgb.train(params, dtrain, ...), users can enable
    checkpointing in script mode by creating a dictionary of xgb.train
    arguments:

        train_args = dict(
            params=params,
            dtrain=dtrain,
            num_boost_round=num_round,
            evals=[(dtrain, 'train'), (dtest, 'test')]
        )

    and calling:

        bst = checkpointing.train(train_args, checkpoint_dir)

    :param train_args: dict of keyword arguments forwarded to xgb.train().
        The dict is not modified; a shallow copy is taken before any changes.
    :param checkpoint_dir: directory to load an existing checkpoint from
        (if present) and to save new checkpoints into.
    :return: the trained xgboost Booster.
    """
    train_args = train_args.copy()
    xgb_model, start_iteration = load_checkpoint(checkpoint_dir)
    # xgboost's default value for num_boost_round is 10.
    # https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.training
    # If num_boost_round <= 0, xgb.train() doesn't actually train and
    # immediately returns a Booster object.
    train_args["num_boost_round"] = train_args.get("num_boost_round", 10) - start_iteration
    if xgb_model is not None:
        logging.info("Checkpoint loaded from %s", xgb_model)
        logging.info("Resuming from iteration %s", start_iteration)

    # Copy the callbacks list: train_args.copy() above is shallow, so
    # appending to the original list would mutate the caller's dict and
    # accumulate duplicate callbacks across repeated train() calls.
    callbacks = list(train_args.get("callbacks", []))
    callbacks.append(print_checkpointed_evaluation(start_iteration=start_iteration,
                                                   end_iteration=train_args["num_boost_round"]))
    callbacks.append(save_checkpoint(checkpoint_dir, start_iteration=start_iteration, iteration=start_iteration,
                                     end_iteration=train_args["num_boost_round"]))
    train_args["verbose_eval"] = False  # suppress xgboost's print_evaluation()
    train_args["xgb_model"] = xgb_model
    train_args["callbacks"] = callbacks

    booster = xgb.train(**train_args)
    return booster