def main()

in 6-Pipelines/config/xgboost_customer_churn.py
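
main() relies on these module-level imports, together with the parse_args and create_smdebug_hook helpers defined elsewhere in the same file:

import os
import pickle

import xgboost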


def main():
    
    args = parse_args()

    train, validation = args.train, args.validation

    # XGBoost's URI loader takes a query-string suffix describing the input:
    # CSV format with the label in column 0.
    parse_csv = "?format=csv&label_column=0"
    dtrain = xgboost.DMatrix(train + parse_csv)
    dval = xgboost.DMatrix(validation + parse_csv)

    # Evaluation sets whose metrics are reported at each boosting round.
    watchlist = [(dtrain, "train"), (dval, "validation")]

    # Booster hyperparameters forwarded from the command line.
    params = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "subsample": args.subsample,
        "verbosity": args.verbosity,
        "objective": args.objective}

    # output_uri is the URI of the S3 bucket where the debugging metrics will
    # be saved; --smdebug_path takes precedence when it is provided.
    output_uri = (
        args.smdebug_path
        if args.smdebug_path is not None
        else args.output_uri
    )

    # smdebug collection names arrive as a single comma-separated string.
    collections = (
        args.smdebug_collections.split(',')
        if args.smdebug_collections is not None
        else None
    )

    hook = create_smdebug_hook(
        out_dir=output_uri,
        frequency=args.smdebug_frequency,
        collections=collections,
        train_data=dtrain,
        validation_data=dval,
    )

    # The smdebug hook doubles as an XGBoost callback, capturing tensors
    # during boosting.
    bst = xgboost.train(
        params=params,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        callbacks=[hook])
    
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    # Pickle the trained booster into the directory that SageMaker packages
    # as the model artifact.
    model_location = os.path.join(args.model_dir, "xgboost-model")
    with open(model_location, "wb") as f:
        pickle.dump(bst, f)
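
parse_args is defined elsewhere in the file and is not shown here. The sketch below is inferred from the attributes main() reads; the defaults are illustrative rather than the file's actual values, and the SM_CHANNEL_*/SM_MODEL_DIR environment variables are the standard SageMaker injection points for channel and model paths:

import argparse
import os


def parse_args():
    parser = argparse.ArgumentParser()
    # Booster hyperparameters (defaults are illustrative).
    parser.add_argument("--max_depth", type=int, default=5)
    parser.add_argument("--eta", type=float, default=0.2)
    parser.add_argument("--gamma", type=float, default=4)
    parser.add_argument("--min_child_weight", type=float, default=6)
    parser.add_argument("--subsample", type=float, default=0.8)
    parser.add_argument("--verbosity", type=int, default=0)
    parser.add_argument("--objective", type=str, default="binary:logistic")
    parser.add_argument("--num_round", type=int, default=50)
    # smdebug settings.
    parser.add_argument("--smdebug_path", type=str, default=None)
    parser.add_argument("--smdebug_frequency", type=int, default=1)
    parser.add_argument("--smdebug_collections", type=str, default=None)
    parser.add_argument("--output_uri", type=str, default="/opt/ml/output/tensors")
    # SageMaker injects these paths through environment variables.
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--validation", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION"))
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    return parser.parse_args()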
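
create_smdebug_hook is also defined elsewhere in the file. A minimal sketch consistent with the call site above, assuming the smdebug library's XGBoost Hook and SaveConfig:

from smdebug import SaveConfig
from smdebug.xgboost import Hook


def create_smdebug_hook(out_dir, frequency=1, collections=None,
                        train_data=None, validation_data=None):
    # Save tensors every `frequency` steps, limited to the requested
    # collections when any are named.
    save_config = SaveConfig(save_interval=frequency)
    return Hook(
        out_dir=out_dir,
        train_data=train_data,
        validation_data=validation_data,
        save_config=save_config,
        include_collections=collections,
    )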
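
Because the booster is pickled rather than written with Booster.save_model, downstream consumers unpickle it. For example, after extracting the model artifact:

import pickle

with open("xgboost-model", "rb") as f:  # path inside the unpacked model artifact
    bst = pickle.load(f)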