# 6-Pipelines/config/xgboost_customer_churn.py
import os
import pickle

import xgboost


def main():
    args = parse_args()
    train, validation = args.train, args.validation

    # XGBoost can read CSV directly when the data URI carries format hints;
    # label_column=0 marks the first column as the label.
    parse_csv = "?format=csv&label_column=0"
    dtrain = xgboost.DMatrix(train + parse_csv)
    dval = xgboost.DMatrix(validation + parse_csv)
    watchlist = [(dtrain, "train"), (dval, "validation")]
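
    # The watchlist makes xgboost.train report evaluation metrics on both
    # splits at every boosting round; the smdebug hook below records them.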
    params = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "subsample": args.subsample,
        "verbosity": args.verbosity,
        "objective": args.objective,
    }
    # output_uri is the URI of the S3 bucket (or local path) where the
    # debugger tensors and metrics will be saved.
    output_uri = (
        args.smdebug_path if args.smdebug_path is not None else args.output_uri
    )
    collections = (
        args.smdebug_collections.split(",")
        if args.smdebug_collections is not None
        else None
    )
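
    # Collection names select which tensor groups smdebug records, e.g.
    # "metrics" or "feature_importance" (names from the smdebug XGBoost docs,
    # not from this file).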
    hook = create_smdebug_hook(
        out_dir=output_uri,
        frequency=args.smdebug_frequency,
        collections=collections,
        train_data=dtrain,
        validation_data=dval,
    )
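
    # The hook doubles as an xgboost callback: passed to xgboost.train below,
    # it saves the selected collections every `frequency` steps. A hedged
    # sketch of create_smdebug_hook appears at the end of this file.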
    bst = xgboost.train(
        params=params,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        callbacks=[hook],
    )
    # Persist the trained booster where SageMaker expects it (model_dir) so it
    # can be packaged into model.tar.gz for hosting.
    os.makedirs(args.model_dir, exist_ok=True)
    model_location = os.path.join(args.model_dir, "xgboost-model")
    with open(model_location, "wb") as f:
        pickle.dump(bst, f)
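

# A minimal sketch of the create_smdebug_hook helper called in main(), assuming
# the smdebug.xgboost API (Hook, SaveConfig). The actual helper is defined
# elsewhere in this file/repo and may differ; treat this as illustrative.
def create_smdebug_hook(out_dir, train_data=None, validation_data=None,
                        frequency=1, collections=None):
    # Imported here so the sketch stays self-contained.
    from smdebug import SaveConfig
    from smdebug.xgboost import Hook

    # Save the chosen collections every `frequency` boosting rounds.
    save_config = SaveConfig(save_interval=frequency)
    return Hook(
        out_dir=out_dir,
        train_data=train_data,
        validation_data=validation_data,
        save_config=save_config,
        include_collections=collections,
    )


if __name__ == "__main__":
    main()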