in src/baselines/lgbm.py [0:0]
def trainer(args, train_dataset, dev_dataset, log_fp=None):
    """Train a LightGBM booster and save it to a checkpoint directory.

    Args:
        args: Run configuration. This function reads ``log_dir``,
            ``max_metric_type``, ``num_train_epochs``,
            ``early_stopping_rounds`` and ``output_dir``.
        train_dataset: LightGBM training ``Dataset``; it is also evaluated
            as the "train" validation set.
        dev_dataset: LightGBM ``Dataset`` evaluated as the "dev" set.
        log_fp: Optional writable text handle for logs. When omitted,
            ``<args.log_dir>/logs.txt`` is opened here and closed before
            returning; a caller-supplied handle is left open.

    Returns:
        ``(global_step, tr_loss, max_metric_model_info)``:
        ``global_step`` is the best iteration when early stopping ran,
        otherwise the number of boosting rounds performed;
        ``tr_loss`` is the training ``binary_logloss`` at that point;
        ``max_metric_model_info`` is ``{"global_step": global_step}``.
    """
    # Only close the log file if this function opened it.
    owns_log_fp = log_fp is None
    if owns_log_fp:
        log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w", encoding="utf-8")
    try:
        feature_size = train_dataset.num_feature()
        feature_names = ["f_%d" % idx for idx in range(feature_size)]
        feval = _select_feval(args.max_metric_type)

        param = get_model_param(args)
        log_fp.write(json.dumps(param, ensure_ascii=False) + "\n")
        log_fp.write("#" * 50 + "\n")
        log_fp.flush()

        evals_result = {}
        # NOTE(review): `early_stopping_rounds` / `verbose_eval` are keyword
        # arguments of older LightGBM releases (removed in 4.0, where the
        # `lgb.early_stopping` / `lgb.log_evaluation` callbacks replace them).
        lgbm_model = lgb.train(
            params=param,
            train_set=train_dataset,
            num_boost_round=args.num_train_epochs,
            valid_sets=[train_dataset, dev_dataset],
            valid_names=["train", "dev"],
            callbacks=[lgb.record_evaluation(evals_result)],
            feval=feval,
            feature_name=feature_names,
            early_stopping_rounds=args.early_stopping_rounds if args.early_stopping_rounds > 0 else None,
            verbose_eval=1,
            keep_training_booster=True
        )

        best_iteration = lgbm_model.best_iteration
        log_fp.write("best iteration: %d\n" % best_iteration)
        log_fp.write("num trees: %d\n" % lgbm_model.num_trees())

        # best_iteration (1-based) is only populated by early stopping; without
        # it LightGBM leaves it at 0, so use the number of rounds actually run.
        if best_iteration > 0:
            global_step = best_iteration
        else:
            global_step = lgbm_model.current_iteration()

        checkpoint = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
        os.makedirs(checkpoint, exist_ok=True)

        log_fp.write(str(evals_result) + "\n" + "#" * 50 + "\n")
        # Assumes the configured metrics include binary_logloss (KeyError otherwise).
        train_losses = evals_result["train"]["binary_logloss"]
        # Without early stopping, report the last recorded training loss.
        tr_loss = train_losses[best_iteration - 1] if best_iteration > 0 else train_losses[-1]
        max_metric_model_info = {"global_step": global_step}

        save_model(lgbm_model, os.path.join(checkpoint, "lgbm_model.txt"))
        param["best_iteration"] = best_iteration
        with open(os.path.join(checkpoint, "config.json"), "w", encoding="utf-8") as config_fp:
            json.dump(param, config_fp)
        log_fp.flush()
        return global_step, tr_loss, max_metric_model_info
    finally:
        if owns_log_fp:
            log_fp.close()


def _select_feval(metric_type):
    """Return the custom LightGBM eval function for ``metric_type`` (default: f1)."""
    fevals = {
        "acc": acc,
        "prec": precision,
        "recall": recall,
        "f1": f1,
        "pr_auc": pr_auc,
        "roc_auc": roc_auc,
    }
    return fevals.get(metric_type, f1)