# train()
#
# from src/sagemaker_xgboost_container/checkpointing.py [0:0]


def train(train_args, checkpoint_dir):
    """Convenience function for script mode.

    Instead of running xgb.train(params, dtrain, ...), users can enable
    checkpointing in script mode by creating a dictionary of xgb.train
    arguments:

    train_args = dict(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_round,
        evals=[(dtrain, 'train'), (dtest, 'test')]
    )

    and calling:

    bst = checkpointing.train(train_args, checkpoint_dir)

    Args:
        train_args: dict of keyword arguments for xgb.train(). Not mutated;
            a shallow copy is modified internally.
        checkpoint_dir: directory to load existing checkpoints from and to
            save new per-iteration checkpoints into.

    Returns:
        The trained xgboost Booster.
    """
    # Shallow-copy so the caller's dict keys are untouched; nested values
    # that we modify (the callbacks list) are copied separately below.
    train_args = train_args.copy()

    xgb_model, start_iteration = load_checkpoint(checkpoint_dir)

    # xgboost's default value for num_boost_round is 10.
    # https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.training
    # If num_boost_round <= 0, xgb.train() doesn't actually train and
    # immediately returns a Booster object.
    train_args["num_boost_round"] = train_args.get("num_boost_round", 10) - start_iteration

    if xgb_model is not None:
        logging.info("Checkpoint loaded from %s", xgb_model)
        logging.info("Resuming from iteration %s", start_iteration)

    # Copy the callbacks list: appending to train_args.get(...) directly would
    # mutate the caller's list, since dict.copy() above is only shallow.
    callbacks = list(train_args.get("callbacks", []))
    callbacks.append(print_checkpointed_evaluation(start_iteration=start_iteration,
                                                   end_iteration=train_args["num_boost_round"]))
    callbacks.append(save_checkpoint(checkpoint_dir, start_iteration=start_iteration, iteration=start_iteration,
                                     end_iteration=train_args["num_boost_round"]))

    train_args["verbose_eval"] = False  # suppress xgboost's print_evaluation()
    train_args["xgb_model"] = xgb_model
    train_args["callbacks"] = callbacks

    booster = xgb.train(**train_args)

    return booster