def train_with_grid_search()

in src/baselines/xgb.py


import os
import json

import joblib
import numpy as np
from sklearn.model_selection import GridSearchCV, PredefinedSplit, RandomizedSearchCV
from xgboost import XGBClassifier

# logger, get_model_param, and save_model are module-level helpers in src/baselines/xgb.py.


def train_with_grid_search(args, X_train, y_train, X_dev, y_dev, log_fp=None):
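    """Run a randomized hyperparameter search for an XGBoost classifier.

    Trains on X_train, scores candidates on the predefined X_dev split,
    logs the best score and parameters, saves the best estimator and its
    config, and returns (global_step, tr_loss, max_metric_model_info).
    """
    # Candidate values sampled by RandomizedSearchCV below.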
    parameters = {
        "max_depth": [9, 15, 17, 20, 25, 30, 35, 39],
        "n_estimators": [50, 100, 200, 250, 300],
        "learning_rate": [0.001, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2],
        "min_child_weight": [0, 1, 3, 5, 7],
        "max_delta_step": [0, 0.2, 0.6, 1, 2],
        "subsample": [0.6, 0.7, 0.8, 0.85, 0.95],
        "colsample_bytree": [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        "reg_alpha": [0, 0.25, 0.5, 0.75, 1],
        "reg_lambda": [0.2, 0.4, 0.6, 0.8, 1],
        "scale_pos_weight": [1, 10, 20, 40]
    }
    # Map the requested metric to a scikit-learn scoring string.
    scoring_by_metric = {
        "acc": "accuracy",
        "prec": "precision",
        "recall": "recall",
        "f1": "f1_macro",
        "pr_auc": "average_precision",
        "roc_auc": "roc_auc",
    }
    feval = scoring_by_metric.get(args.max_metric_type, "f1_macro")
    # Default hyperparameters from the run config; the grid above overrides
    # the searched ones per candidate.
    param = get_model_param(args)

    xgb = XGBClassifier(
        objective=param["objective"],
        base_score=0.5,
        booster=param["booster"],
        colsample_bylevel=1,
        colsample_bynode=1,
        colsample_bytree=param["colsample_bytree"],
        gamma=param["gamma"],
        gpu_id=args.gpu,
        importance_type='gain',
        interaction_constraints='',
        learning_rate=param["eta"],
        max_delta_step=0,
        max_depth=param["max_depth"],
        min_child_weight=param["min_child_weight"],
        missing=np.nan,
        monotone_constraints='()',
        n_estimators=100,
        n_jobs=param["nthread"],
        num_parallel_tree=1,
        random_state=args.seed,
        reg_alpha=0,
        reg_lambda=param["lambda"],
        scale_pos_weight=args.pos_weight,
        subsample=param["subsample"],
        tree_method='exact', # exact, approx, hist, gpu_hist
        # validate_parameters=1,
        verbosity=param["verbosity"],
        eval_metric=['auc', 'logloss']
    )

    # Combine train and dev so the search can score candidates on the
    # predefined dev split rather than on random CV folds.
    train_val_features = np.concatenate((X_train, X_dev), axis=0)
    train_val_labels = np.concatenate((y_train, y_dev), axis=0)

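    # PredefinedSplit: rows marked -1 stay in training for every split;
    # the dev rows (fold 0) form the single validation fold.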
    dev_fold = np.zeros(train_val_features.shape[0])
    dev_fold[:X_train.shape[0]] = -1
    ps = PredefinedSplit(test_fold=dev_fold)

    # The full grid has ~21 million combinations, so an exhaustive
    # GridSearchCV is impractical; RandomizedSearchCV samples candidates
    # instead (n_iter defaults to 10).
    # gsearch = GridSearchCV(xgb, param_grid=parameters, scoring=feval, cv=ps)
    gsearch = RandomizedSearchCV(xgb, param_distributions=parameters,
                                 scoring=feval, cv=ps)
    gsearch.fit(train_val_features,
                train_val_labels,
                # Extra kwargs are forwarded to XGBClassifier.fit; passing
                # early_stopping_rounds here works with older xgboost releases
                # (newer ones take it in the constructor instead).
                eval_set=[(X_dev, y_dev)],
                early_stopping_rounds=args.early_stopping_rounds if args.early_stopping_rounds > 0 else None,
                verbose=1)

    logger.info("Best score[%s]: %0.6f" % (args.max_metric_type, gsearch.best_score_))
    logger.info("Best parameters set:")
    log_fp.write("Best score[%s]: %0.6f\n" % (args.max_metric_type, gsearch.best_score_))
    log_fp.write("Best parameters set:\n")
    best_parameters = gsearch.best_estimator_.get_params()
    for param_name in sorted(parameters.keys()):
        logger.info("\t%s: %r" % (param_name, best_parameters[param_name]))
        log_fp.write("%s: %r\n" % (param_name, best_parameters[param_name]))
    log_fp.write("#" * 50 + "\n")
    global_step = best_parameters["n_estimators"]
    prefix = "checkpoint-{}".format(global_step)
    checkpoint = os.path.join(args.output_dir, prefix)
    os.makedirs(checkpoint, exist_ok=True)
    # Persist the whole search object (including cv_results_) for later inspection.
    joblib.dump(gsearch, os.path.join(args.output_dir, "xgb_gridsearch.pkl"))

    if log_fp:
        log_fp.write(str(gsearch.best_estimator_) + "\n" + "#" * 50 + "\n")

    tr_loss = 0  # the sklearn search API does not expose a running training loss
    max_metric_model_info = {"global_step": global_step}
    save_model(gsearch.best_estimator_, os.path.join(checkpoint, "xgb_model.txt"))
    with open(os.path.join(checkpoint, "config.json"), "w") as config_fp:
        json.dump(best_parameters, config_fp)
    return global_step, tr_loss, max_metric_model_info
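
A minimal invocation sketch, assuming `get_model_param(args)` can build its base parameter dict from a namespace like the one below (the fields shown cover only what this function reads directly; the real repo's args come from its CLI parser, and `make_classification` stands in for the project's features):

import os
from types import SimpleNamespace

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic stand-in data.
X, y = make_classification(n_samples=2000, n_features=32, random_state=42)
X_train, X_dev, y_train, y_dev = train_test_split(
    X, y, test_size=0.2, random_state=42)

# Hypothetical namespace mirroring the CLI args this function reads.
args = SimpleNamespace(
    max_metric_type="f1",          # -> scoring="f1_macro"
    gpu=0,
    seed=42,
    pos_weight=1,
    early_stopping_rounds=20,
    output_dir="outputs/xgb",
)
os.makedirs(args.output_dir, exist_ok=True)

with open(os.path.join(args.output_dir, "train.log"), "w") as log_fp:
    global_step, tr_loss, info = train_with_grid_search(
        args, X_train, y_train, X_dev, y_dev, log_fp=log_fp)
print("best n_estimators:", global_step)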