def _xgb_train()

in xgboost_script_mode_managed_spot_training_checkpointing/abalone.py


import logging
import os
import pickle as pkl

# checkpointing helpers ship with the SageMaker XGBoost container image
from sagemaker_xgboost_container import checkpointing


def _xgb_train(params, dtrain, evals, num_boost_round, model_dir, is_master, checkpoint_path):
    """Run xgb.train on the given arguments with rabit initialized.

    This is our rabit execution function.

    :param params: Booster parameters passed through to xgb.train().
    :param dtrain: Training data as an xgboost.DMatrix.
    :param evals: List of (DMatrix, name) pairs evaluated each round.
    :param num_boost_round: Total number of boosting rounds to run.
    :param model_dir: Directory where the master host writes the final model.
    :param is_master: True if the current node is the master host in distributed
                      training, or if this is a single-node training job.
                      Note that rabit_run will include this argument.
    :param checkpoint_path: Directory where per-round checkpoints are written and
                            from which training resumes after an interruption.
    """
    logging.info("params: {}, num_boost_round: {}, checkpoint_path: {}".format(
        params, num_boost_round, checkpoint_path))

    train_args = dict(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=evals
    )

    # Resumes from the newest checkpoint in checkpoint_path (if any) and saves
    # a new checkpoint after every boosting round.
    booster = checkpointing.train(train_args, checkpoint_path)

    if is_master:
        model_location = os.path.join(model_dir, 'xgboost-model')
        with open(model_location, 'wb') as f:
            pkl.dump(booster, f)
        logging.info("Stored trained model at {}".format(model_location))