def trainer()

in src/baselines/xgb.py [0:0]


def trainer(args, train_dataset, dev_dataset, log_fp=None):
    """Train an XGBoost booster, log diagnostics and save the best checkpoint.

    Args:
        args: run configuration. This function reads ``log_dir`` (only when
            ``log_fp`` is None), ``output_dir``, ``max_metric_type``,
            ``num_train_epochs`` and ``early_stopping_rounds``, plus whatever
            ``get_model_param`` reads.
        train_dataset: training data handed to ``xgboost.train`` (presumably an
            ``xgboost.DMatrix`` — confirm with callers); it is also evaluated
            under the name ``"train"``.
        dev_dataset: evaluation data, evaluated under the name ``"dev"``.
        log_fp: optional writable text stream for logs. When None,
            ``<args.log_dir>/logs.txt`` is opened (truncating it) and is
            closed again before this function returns or raises.

    Returns:
        ``(global_step, tr_loss, max_metric_model_info)``:

        - ``global_step``: the 0-based best iteration.
        - ``tr_loss``: the training ``logloss`` recorded at that iteration.
        - ``max_metric_model_info``: ``{"global_step": global_step}``.
    """
    owns_log = log_fp is None
    if owns_log:
        # UTF-8 because parameters are dumped with ensure_ascii=False.
        log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w", encoding="utf-8")
    try:
        return _train_and_checkpoint(args, train_dataset, dev_dataset, log_fp)
    finally:
        # Only close a stream we opened; a caller-supplied stream stays open.
        if owns_log:
            log_fp.close()


def _train_and_checkpoint(args, train_dataset, dev_dataset, log_fp):
    """Run ``xgboost.train``, write diagnostics to ``log_fp`` and save the checkpoint."""
    param = get_model_param(args)
    log_fp.write(json.dumps(param, ensure_ascii=False) + "\n")
    log_fp.write("#" * 50 + "\n")
    log_fp.flush()

    evals_result = {}
    # A non-positive value disables early stopping.
    early_stopping = args.early_stopping_rounds if args.early_stopping_rounds > 0 else None
    xgb_model = xgboost.train(
        param,
        train_dataset,
        num_boost_round=args.num_train_epochs,
        evals=[(train_dataset, "train"), (dev_dataset, "dev")],
        obj=None,
        custom_metric=_select_feval(args.max_metric_type),
        maximize=True,
        early_stopping_rounds=early_stopping,
        evals_result=evals_result,
        verbose_eval=True,
        xgb_model=None,
        callbacks=None,
    )

    global_step = _best_iteration(xgb_model)
    log_fp.write("best iteration: %d\n" % global_step)
    log_fp.write("num trees: %d\n" % _num_trees(xgb_model))

    checkpoint = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
    os.makedirs(checkpoint, exist_ok=True)

    xgboost.plot_importance(xgb_model)
    plt.savefig(os.path.join(checkpoint, "feature_importance.png"), dpi=100)
    plt.close("all")

    log_fp.write(str(evals_result) + "\n" + "#" * 50 + "\n")
    # best_iteration is 0-based and evals_result stores one value per boosting
    # round, so the best round's loss lives at index global_step (not -1).
    tr_loss = evals_result["train"]["logloss"][global_step]
    max_metric_model_info = {"global_step": global_step}

    save_model(xgb_model, os.path.join(checkpoint, "xgb_model.txt"))
    param["best_iteration"] = global_step
    with open(os.path.join(checkpoint, "config.json"), "w", encoding="utf-8") as config_fp:
        json.dump(param, config_fp)
    log_fp.flush()
    return global_step, tr_loss, max_metric_model_info


def _select_feval(metric_type):
    """Map ``args.max_metric_type`` to its custom metric; unknown names use ``f1``."""
    fevals = {
        "acc": acc,
        "prec": precision,
        "recall": recall,
        "f1": f1,
        "pr_auc": pr_auc,
        "roc_auc": roc_auc,
    }
    return fevals.get(metric_type, f1)


def _best_iteration(booster):
    """Return the booster's 0-based best iteration.

    Depending on the XGBoost version, ``best_iteration`` may be undefined when
    early stopping was not used. In that case, fall back to the last boosted
    round.
    """
    try:
        return booster.best_iteration
    except AttributeError:
        return booster.num_boosted_rounds() - 1


def _num_trees(booster):
    """Return the number of trees in ``booster``.

    Use ``num_trees()`` when the object provides it. Otherwise count the
    per-tree text dumps returned by ``get_dump()``.
    """
    num_trees = getattr(booster, "num_trees", None)
    if callable(num_trees):
        return num_trees()
    return len(booster.get_dump())