in src/baselines/xgb.py [0:0]
import json
import os

import matplotlib.pyplot as plt
import xgboost

# get_model_param, save_model and the metric callables used below are assumed
# to be defined elsewhere in this module


def trainer(args, train_dataset, dev_dataset, log_fp=None):
    if log_fp is None:
        log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w")
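    # record the resolved model hyper-parameters at the top of the log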
    param = get_model_param(args)
    log_fp.write(json.dumps(param, ensure_ascii=False) + "\n")
    log_fp.write("#" * 50 + "\n")
    log_fp.flush()
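    # map the configured metric name to its custom-metric callable; acc,
    # precision, recall, f1, pr_auc and roc_auc come from elsewhere in this module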
    evals_result = {}
    metric_to_feval = {
        "acc": acc,
        "prec": precision,
        "recall": recall,
        "f1": f1,
        "pr_auc": pr_auc,
        "roc_auc": roc_auc,
    }
    # fall back to f1 when an unknown metric name is configured
    feval = metric_to_feval.get(args.max_metric_type, f1)
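    # train with the dev split as the early-stopping monitor; evals_result
    # collects the per-round metric history for both splits
    # (custom_metric is the xgboost >= 1.6 spelling; older versions used feval)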
    xgb_model = xgboost.train(
        param,
        train_dataset,
        num_boost_round=args.num_train_epochs,
        evals=[(train_dataset, "train"), (dev_dataset, "dev")],
        obj=None,
        custom_metric=feval,
        maximize=True,
        early_stopping_rounds=args.early_stopping_rounds if args.early_stopping_rounds > 0 else None,
        evals_result=evals_result,
        verbose_eval=True,
        xgb_model=None,
        callbacks=None,
    )
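    # best_iteration is only set when early stopping was enabled; log it
    # together with the total number of rounds kept on the booster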
log_fp.write("best iteration: %d\n" % xgb_model.best_iteration)
log_fp.write("num trees: %d\n" % xgb_model.num_trees())
global_step = xgb_model.best_iteration
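    # materialize a checkpoint directory named after the best iteration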
prefix = "checkpoint-{}".format(global_step)
checkpoint = os.path.join(args.output_dir, prefix)
if not os.path.exists(checkpoint):
os.makedirs(checkpoint)
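    # save the feature-importance plot alongside the checkpoint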
    xgboost.plot_importance(xgb_model)
    plt.savefig(os.path.join(checkpoint, "feature_importance.png"), dpi=100)
    plt.close("all")
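    # dump the full metric history, then pull the train logloss at the best round
    # (assumes param's eval_metric yields a "logloss" entry in evals_result)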
    log_fp.write(str(evals_result) + "\n" + "#" * 50 + "\n")
    # best_iteration is 0-based, so it indexes the history directly (not - 1)
    tr_loss = evals_result["train"]["logloss"][global_step]
    max_metric_model_info = {"global_step": global_step}
    save_model(xgb_model, os.path.join(checkpoint, "xgb_model.txt"))
    param["best_iteration"] = xgb_model.best_iteration
    with open(os.path.join(checkpoint, "config.json"), "w") as config_fp:
        json.dump(param, config_fp)
    return global_step, tr_loss, max_metric_model_info
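

# Minimal usage sketch (an assumption, not part of this module): `args` must
# carry log_dir, output_dir, num_train_epochs, early_stopping_rounds and
# max_metric_type, and both datasets must be xgboost.DMatrix instances.
#
#     dtrain = xgboost.DMatrix(X_train, label=y_train)
#     ddev = xgboost.DMatrix(X_dev, label=y_dev)
#     global_step, tr_loss, info = trainer(args, dtrain, ddev)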