def main()

in training/train_xgboost.py [0:0]


def main():
    """Train an XGBoost booster from CSV inputs and pickle it to the model dir.

    Reads ``train_features.csv`` / ``train_labels.csv`` from ``args.train`` and
    ``val_features.csv`` / ``val_labels.csv`` from ``args.validation``, trains
    a booster for up to ``args.num_round`` rounds while evaluating on both
    sets, and writes the pickled booster to ``$SM_MODEL_DIR/model.bin``.

    Raises:
        RuntimeError: if the ``SM_MODEL_DIR`` environment variable is unset
            or empty.
    """
    args = parse_args()

    # Resolve the output directory first so a misconfigured environment fails
    # immediately instead of after the whole training run has completed.
    model_dir = os.environ.get("SM_MODEL_DIR")
    if not model_dir:
        raise RuntimeError(
            "SM_MODEL_DIR environment variable is not set; "
            "cannot determine where to save the model"
        )

    train_features_path = os.path.join(args.train, "train_features.csv")
    train_labels_path = os.path.join(args.train, "train_labels.csv")

    val_features_path = os.path.join(args.validation, "val_features.csv")
    val_labels_path = os.path.join(args.validation, "val_labels.csv")

    print("Loading training dataframes...")
    df_train_features = pd.read_csv(train_features_path)
    df_train_labels = pd.read_csv(train_labels_path)

    print("Loading validation dataframes...")
    df_val_features = pd.read_csv(val_features_path)
    df_val_labels = pd.read_csv(val_labels_path)

    dtrain = xgboost.DMatrix(df_train_features.values, label=df_train_labels.values)
    dval = xgboost.DMatrix(df_val_features.values, label=df_val_labels.values)

    # Validation is listed last on purpose: per the xgboost.train docs, the
    # last entry in ``evals`` is the one used for early stopping.
    watchlist = [(dtrain, "train"), (dval, "validation")]

    # Booster parameters only. ``early_stopping_rounds`` is deliberately NOT
    # here: it is an argument of ``xgboost.train`` itself, and a booster does
    # not act on it when it is passed inside ``params``.
    params = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        # NOTE(review): newer XGBoost releases replaced ``silent`` with
        # ``verbosity`` -- confirm against the version pinned for this job.
        "silent": args.silent,
        "objective": args.objective,
        "subsample": args.subsample,
        "eval_metric": args.eval_metric,
    }

    bst = xgboost.train(
        params=params,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        early_stopping_rounds=args.early_stopping_rounds,
    )

    # Context manager guarantees the file is flushed and closed even if
    # pickling raises part-way through.
    model_path = os.path.join(model_dir, "model.bin")
    with open(model_path, "wb") as model_file:
        pkl.dump(bst, model_file)