src/baselines/lgbm.py (602 lines of code) (raw):

#!/usr/bin/env python # encoding: utf-8 ''' *Copyright (c) 2023, Alibaba Group; *Licensed under the Apache License, Version 2.0 (the "License"); *you may not use this file except in compliance with the License. *You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 *Unless required by applicable law or agreed to in writing, software *distributed under the License is distributed on an "AS IS" BASIS, *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *See the License for the specific language governing permissions and *limitations under the License. @author: Hey @email: sanyuan.**@**.com @tel: 137****6540 @datetime: 2022/12/29 19:35 @project: DeepProtFunc @file: lgbm @desc: LGBM (based protein structure embeddding (matrix or vector) for classification ''' import json import logging import os, sys import pickle5 import argparse import joblib import lightgbm as lgb from lightgbm import LGBMClassifier from sklearn.model_selection import GridSearchCV, PredefinedSplit, RandomizedSearchCV import pandas as pd sys.path.append(".") sys.path.append("..") sys.path.append("../../") sys.path.append("../../src") try: from utils import * from multi_label_metrics import * from metrics import * except ImportError: from src.utils import * from src.common.multi_label_metrics import * from src.common.metrics import * logger = logging.getLogger(__name__) def main(): parser = argparse.ArgumentParser() parser.add_argument("--data_dir", default=None, type=str, required=True, help="input dir, including: *.csv/*.txt.") parser.add_argument("--separate_file", action="store_true", help="the id of each sample in the dataset is separate from its details") parser.add_argument("--tfrecords", action="store_true", help="whether the dataset is in tfrecords") parser.add_argument("--filename_pattern", default=None, type=str, help="the dataset filename pattern, such as {}_with_pdb_emb.csv including train_with_pdb_emb.csv, dev_with_pdb_emb.csv, and test_with_pdb_emb.csv in ${data_dir}") parser.add_argument("--dataset_name", default="rdrp_40_extend", type=str, required=True, help="dataset name") parser.add_argument("--dataset_type", default="protein", type=str, required=True, choices=["protein", "dna", "rna"], help="dataset type") parser.add_argument("--task_type", default="multi_label", type=str, required=True, choices=["multi_label", "multi_class", "binary_class"], help="task type") parser.add_argument("--label_type", default=None, type=str, required=True, help="label type.") parser.add_argument("--label_filepath", default=None, type=str, required=True, help="label filepath.") parser.add_argument("--output_dir", default="./result/", type=str, required=True, help="output dir.") parser.add_argument("--log_dir", default="./logs/", type=str, required=True, help="log dir.") parser.add_argument("--tb_log_dir", default="./tb-logs/", type=str, required=True, help="tensorboard log dir.") # Other parameters parser.add_argument("--config_path", default=None, type=str, required=True, help="the model configuration filepath") parser.add_argument("--cache_dir", default="", type=str, help="cache dir") parser.add_argument("--do_train", action="store_true", help="whether to run training.") parser.add_argument("--do_eval", action="store_true", help="whether to run eval on the dev set.") parser.add_argument("--do_predict", action="store_true", help="whether to run predict on the test set.") parser.add_argument("--evaluate_during_training", action="store_true", help="whether to evaluation during training at each logging step.") parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="batch size per GPU/CPU for training.") parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="batch size per GPU/CPU for evaluation.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument("--learning_rate", default=1e-4, type=float, help="the initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.0, type=float, help="weight decay if we apply some.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="epsilon for Adam optimizer.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="max gradient norm.") parser.add_argument("--num_train_epochs", default=3.0, type=int, help="total number of training epochs to perform.") parser.add_argument("--max_steps", default=-1, type=int, help="if > 0: set total number of training steps to perform. Override num_train_epochs.") parser.add_argument("--warmup_steps", default=0, type=int, help="linear warmup over warmup_steps.") parser.add_argument("--logging_steps", type=int, default=50, help="log every X updates steps.") parser.add_argument("--save_steps", type=int, default=50, help="save checkpoint every X updates steps.") parser.add_argument("--overwrite_output_dir", action="store_true", help="overwrite the content of the output directory") parser.add_argument("--overwrite_cache", action="store_true", help="overwrite the cached training and evaluation sets") parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") parser.add_argument("--local_rank", type=int, default=-1, help="for distributed training: local_rank") parser.add_argument("--unbalance", action="store_true", help="set unbalance.") parser.add_argument("--pos_weight", type=float, default=40, help="positive weight") # which metric for model finalization selected parser.add_argument("--max_metric_type", type=str, default="f1", required=True, choices=["acc", "jaccard", "prec", "recall", "f1", "fmax", "pr_auc", "roc_auc"], help="which metric for model selected") parser.add_argument("--early_stopping_rounds", default=None, type=int, help="early stopping rounds.") parser.add_argument("--grid_search", action="store_true", help="grid search for params or not") args = parser.parse_args() args.output_mode = args.task_type return args def load_dataset(args, dataset_type): ''' load the dataset :param args: :param dataset_type :return: ''' x = [] y = [] if os.path.exists(args.label_filepath): label_list = load_labels(args.label_filepath, header=True) else: label_list = load_labels(os.path.join(args.data_dir, "label.txt"), header=True) label_map = {name: idx for idx, name in enumerate(label_list)} npz_filpath = os.path.join(args.data_dir, "%s_emb.npz" % dataset_type) if os.path.exists(npz_filpath): npzfile = np.load(npz_filpath, allow_pickle=True) x = npzfile["x"] y = npzfile["y"] else: cnt = 0 if args.filename_pattern: filepath = os.path.join(args.data_dir, args.filename_pattern.format(dataset_type)) else: filepath = os.path.join(args.data_dir, "%s_with_pdb_emb.csv" % dataset_type) header = False header_filter = False if filepath.endswith(".csv"): header = True header_filter = True for row in file_reader(filepath, header=header, header_filter=header_filter): prot_id, seq, seq_len, pdb_filename, ptm, mean_plddt, emb_filename, label, source = row embedding_filepath = os.path.join(args.data_dir, "embs", emb_filename) if os.path.exists(embedding_filepath): emb = torch.load(embedding_filepath) embedding_info = emb["bos_representations"][36].numpy() x.append(embedding_info) if args.task_type in ["multi-class", "multi_class"]: label = label_map[label] elif args.task_type == "regression": label = float(label) elif args.task_type in ["multi-label", "multi_label"]: if isinstance(label, str): label = [0] * len(label_map) for label_name in eval(label): label_id = label_map[label_name] label[label_id] = 1 else: label = [0] * len(label_map) for label_name in label: label_id = label_map[label_name] label[label_id] = 1 elif args.task_type in ["binary-class", "binary_class"]: label = label_map[label] y.append(label) cnt += 1 if cnt % 10000 == 0: print("done %d" % cnt) x = np.array(x) y = np.array(y) np.savez(npz_filpath, x=x, y=y) print("%s: x.shape: %s, y.shape: %s" %(dataset_type, str(x.shape), str(y.shape))) return x, y, label_list def acc(probs, labeled_data): higher_better = True name = "acc" targets = labeled_data.get_label() if targets.ndim == 1: if len(np.unique(targets)) == 2: func = binary_f1 else: func = multi_class_f1 elif targets.ndim == 2 and targets.shape[1] > 2: func = multi_label_f1 elif targets.ndim == 2 and targets.shape[1] == 2: func = binary_f1 else: func = multi_class_f1 value = func(targets, probs, threshold=0.5) return name, value, higher_better def f1(probs, labeled_data): higher_better = True name = "f1" targets = labeled_data.get_label() if targets.ndim == 1: if len(np.unique(targets)) == 2: func = binary_f1 else: func = multi_class_f1 elif targets.ndim == 2 and targets.shape[1] > 2: func = multi_label_f1 elif targets.ndim == 2 and targets.shape[1] == 2: func = binary_f1 else: func = multi_class_f1 value = func(targets, probs, threshold=0.5, average="macro") return name, value, higher_better def precision(probs, labeled_data): higher_better = True name = "precision" targets = labeled_data.get_label() if targets.ndim == 1: if len(np.unique(targets)) == 2: func = binary_precision else: func = multi_class_precision elif targets.ndim == 2 and targets.shape[1] > 2: func = multi_label_precision elif targets.ndim == 2 and targets.shape[1] == 2: func = binary_precision else: func = multi_class_precision value = func(targets, probs, threshold=0.5, average="macro") return name, value, higher_better def recall(probs, labeled_data): higher_better = True name = "recall" targets = labeled_data.get_label() if targets.ndim == 1: if len(np.unique(targets)) == 2: func = binary_recall else: func = multi_class_recall elif targets.ndim == 2 and targets.shape[1] > 2: func = multi_label_recall elif targets.ndim == 2 and targets.shape[1] == 2: func = binary_recall else: func = multi_class_recall value = func(targets, probs, threshold=0.5, average="macro") return name, value, higher_better def roc_auc(probs, labeled_data): higher_better = True name = "roc_auc" targets = labeled_data.get_label() if targets.ndim == 1: if len(np.unique(targets)) == 2: func = binary_roc_auc else: func = multi_class_roc_auc elif targets.ndim == 2 and targets.shape[1] > 2: func = multi_label_roc_auc elif targets.ndim == 2 and targets.shape[1] == 2: func = binary_roc_auc else: func = multi_class_roc_auc value = func(labeled_data.get_label(), probs, threshold=0.5, average="macro") return name, value, higher_better def pr_auc(probs, labeled_data): higher_better = True name = "pr_auc" targets = labeled_data.get_label() if targets.ndim == 1: if len(np.unique(targets)) == 2: func = binary_pr_auc else: func = multi_class_pr_auc elif targets.ndim == 2 and targets.shape[1] > 2: func = multi_label_pr_auc elif targets.ndim == 2 and targets.shape[1] == 2: func = binary_pr_auc else: func = multi_class_pr_auc value = func(targets, probs, threshold=0.5, average="macro") return name, value, higher_better def save_model(model, model_path): print("#" * 25 + "save model" + "#" * 25) # dump model with pickle with open(model_path, "wb") as fout: pickle5.dump(model, fout) def load_model(model_path): print("#" * 25 + "load model" + "#" * 25) with open(model_path, "rb") as fin: model = pickle5.load(fin) return model def get_model_param(args): # LGBM Paramater tuning if args.task_type in ["binary_class", "binary-class"]: objective = "binary" elif args.task_type in ["multi_class", "multi-class"]: objective = "multiclass" elif args.task_type in ["multi_label", "multi-label"]: pass # to do elif args.task_type == "regression": objective = "regression" # regression, regression_l1, mape ''' param = { 'num_leaves': 8, 'learning_rate': 0.01, 'feature_fraction': 0.6, 'max_depth': 17, 'objective': objective, # task type binary, regression, regression_l1, mape, multiclass 'boosting_type': args.boosting_type, # gbdt, rf, dart, goss #'metric': ['auc', 'binary'], # 'metric': ['binary_error'], #'metric': 'custom', 'metric': ['auc', 'binary'], "verbose": 1 } ''' config = json.load(open(args.config_path, "r")) config["objective"] = objective config["learning_rate"] = args.learning_rate if args.unbalance: config["is_unbalance"] = True assert config["boosting_type"] in ["gbdt", "rf", "dart", "goss"] return config def run(): args = main() logging.basicConfig(format="%(asctime)s-%(levelname)s-%(name)s | %(message)s", datefmt="%Y/%m/%d %H:%M:%S", level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) # overwrite the output dir if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) else: os.makedirs(args.output_dir) # create logger dir if not os.path.exists(args.log_dir): os.makedirs(args.log_dir) log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w") # create tensorboard logger dir if not os.path.exists(args.tb_log_dir): os.makedirs(args.tb_log_dir) X_train, y_train, label_list = load_dataset(args, "train") X_dev, y_dev, _ = load_dataset(args, "dev") X_test, y_test, _ = load_dataset(args, "test") if args.unbalance: train_dataset = lgb.Dataset(X_train, label=y_train, free_raw_data=False).construct() dev_dataset = lgb.Dataset(X_dev, label=y_dev, free_raw_data=False).construct() test_dataset = lgb.Dataset(X_test, label=y_test, free_raw_data=False).construct() else: if args.grid_search: train_dataset = lgb.Dataset(X_train, label=y_train, free_raw_data=False).construct() dev_dataset = lgb.Dataset(X_dev, label=y_dev, free_raw_data=False).construct() test_dataset = lgb.Dataset(X_test, label=y_test, free_raw_data=False).construct() else: train_weights = np.ones_like(y_train) train_weights[y_train == 1] = args.pos_weight dev_weights = np.ones_like(y_dev) dev_weights[y_dev == 1] = args.pos_weight test_weights = np.ones_like(y_test) test_weights[y_test == 1] = args.pos_weight train_dataset = lgb.Dataset(X_train, label=y_train, weight=train_weights, free_raw_data=False).construct() dev_dataset = lgb.Dataset(X_dev, label=y_dev, weight=dev_weights, free_raw_data=False).construct() test_dataset = lgb.Dataset(X_test, label=y_test, weight=test_weights, free_raw_data=False).construct() args.num_labels = len(np.unique(y_train)) # output training/evaluation hyperparameters into logger logger.info("==== Training/Evaluation Parameters: =====") for attr, value in sorted(args.__dict__.items()): logger.info("\t{}={}".format(attr, value)) logger.info("==== Parameters End =====\n") args_dict = {} for attr, value in sorted(args.__dict__.items()): if attr != "device": args_dict[attr] = value log_fp.write(json.dumps(args_dict, ensure_ascii=False) + "\n") log_fp.write("#" * 50 + "\n") if args.task_type in ["binary_class", "binary-class"]: n_train_pos = y_train.sum() n_dev_pos = y_dev.sum() n_test_pos = y_test.sum() logger.info("Train data, positive: {}, negative: {}, positive ratio: {:.2f}".format(n_train_pos, len(y_train) - n_train_pos, n_train_pos / len(y_train))) logger.info("Dev data, positive: {}, negative: {}, positive ratio: {:.2f}".format(n_dev_pos, len(y_dev) - n_train_pos, n_dev_pos / len(y_dev))) logger.info("Test data, positive: {}, negative: {}, positive ratio: {:.2f}".format(n_test_pos, len(y_test) - n_test_pos, n_test_pos / len(y_test))) log_fp.write("Train data, positive: {}, negative: {}, positive ratio: {:.2f}\n".format(n_train_pos, len(y_train) - n_train_pos, n_train_pos / len(y_train))) log_fp.write("Dev data, positive: {}, negative: {}, positive ratio: {:.2f}\n".format(n_dev_pos, len(y_dev) - n_train_pos, n_dev_pos / len(y_dev))) log_fp.write("Test data, positive: {}, negative: {}, positive ratio: {:.2f}\n".format(n_test_pos, len(y_test) - n_test_pos, n_test_pos / len(y_test))) log_fp.write("#" * 50 + "\n") log_fp.write("num labels: %d\n" % args.num_labels) log_fp.write("#" * 50 + "\n") log_fp.flush() max_metric_model_info = None if args.do_train: logger.info("++++++++++++Training+++++++++++++") if args.grid_search: global_step, tr_loss, max_metric_model_info = train_with_grid_search(args, train_dataset, dev_dataset, log_fp=log_fp) else: global_step, tr_loss, max_metric_model_info = trainer(args, train_dataset, dev_dataset, log_fp=log_fp) logger.info("global_step = %s, average loss = %s", global_step, tr_loss) # save if args.do_train: logger.info("++++++++++++Save Model+++++++++++++") # Create output directory if needed global_step = max_metric_model_info["global_step"] logger.info("max %s global step: %d" % (args.max_metric_type, global_step)) log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step)) prefix = "checkpoint-{}".format(global_step) checkpoint = os.path.join(args.output_dir, prefix) logger.info("Saving model checkpoint to %s", checkpoint) torch.save(args, os.path.join(checkpoint, "training_args.bin")) save_labels(os.path.join(checkpoint, "label.txt"), label_list) # evaluate if args.do_eval and args.local_rank in [-1, 0]: logger.info("++++++++++++Validation+++++++++++++") log_fp.write("++++++++++++Validation+++++++++++++\n") global_step = max_metric_model_info["global_step"] logger.info("max %s global step: %d" % (args.max_metric_type, global_step)) log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step)) prefix = "checkpoint-{}".format(global_step) checkpoint = os.path.join(args.output_dir, prefix) logger.info("checkpoint path: %s" % checkpoint) log_fp.write("checkpoint path: %s\n" % checkpoint) lgbm_model = load_model(os.path.join(checkpoint, "lgbm_model.txt")) result = evaluate(args, lgbm_model, dev_dataset, prefix=prefix, log_fp=log_fp) result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items()) logger.info(json.dumps(result, ensure_ascii=False)) log_fp.write(json.dumps(result, ensure_ascii=False) + "\n") # Testing if args.do_predict and args.local_rank in [-1, 0]: logger.info("++++++++++++Testing+++++++++++++") log_fp.write("++++++++++++Testing+++++++++++++\n") global_step = max_metric_model_info["global_step"] logger.info("max %s global step: %d" % (args.max_metric_type, global_step)) log_fp.write("max %s global step: %d\n" % (args.max_metric_type, global_step)) prefix = "checkpoint-{}".format(global_step) checkpoint = os.path.join(args.output_dir, prefix) logger.info("checkpoint path: %s" % checkpoint) log_fp.write("checkpoint path: %s\n" % checkpoint) lgbm_model = load_model(os.path.join(checkpoint, "lgbm_model.txt")) result = predict(args, lgbm_model, test_dataset, prefix=prefix, log_fp=log_fp) result = dict(("evaluation_" + k + "_{}".format(global_step), v) for k, v in result.items()) logger.info(json.dumps(result, ensure_ascii=False)) log_fp.write(json.dumps(result, ensure_ascii=False) + "\n") log_fp.close() def train_with_grid_search(args, train_dataset, dev_dataset, log_fp=None): parameters = { "num_leaves": [8, 12, 16, 20], "max_depth": [9, 15, 17, 20, 25, 30, 35, 39], "min_split_gain": [0, 0.05, 0.07, 0.09, 0.1, 0.3, 0.5, 0.7, 0.9, 1], "subsample": [0.6, 0.7, 0.8, 0.8, 1], "n_estimators": [50, 100, 200, 250, 300], "learning_rate": [0.001, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2], "feature_fraction": [0.6, 0.7, 0.8, 0.9, 0.95], "bagging_fraction": [0.6, 0.7, 0.8, 0.9, 0.95], "bagging_freq": [2, 4, 5, 6, 8], "lambda_l1": [0, 0.1, 0.4, 0.5, 0.6], "lambda_l2": [0, 10, 15, 35, 40], "colsample_bytree": [0.6, 0.7, 0.8, 0.9, 1], "reg_alpha": [0, 0.01, 0.02, 0.05, 0.09, 0.1, 1], "reg_lambda": [0, 0.1, 0.5, 1], "cat_smooth": [1, 10, 15, 20, 35] } if args.max_metric_type == "acc": feval = "accuracy" elif args.max_metric_type == "prec": feval = "precision" elif args.max_metric_type == "recall": feval = "recall" elif args.max_metric_type == "f1": feval = "f1_macro" elif args.max_metric_type == "pr_auc": feval = "average_precision" elif args.max_metric_type == "roc_auc": feval = "roc_auc" else: feval = "f1_macro" param = get_model_param(args) if args.output_mode in ["binary_class", "binary-class"]: class_weight = {0: 1, 1: args.pos_weight} else: class_weight = None gbm = LGBMClassifier( boosting_type=param["boosting_type"], num_leaves=param["num_leaves"], max_depth=param["max_depth"], learning_rate=param["learning_rate"], n_estimators=100, subsample_for_bin=200000, objective=param["objective"], num_class=1 if args.num_labels == 2 else args.num_labels, class_weight=class_weight, min_split_gain=0, min_child_weight=1e-3, min_child_samples=20, subsample=1, subsample_freq=0, colsample_bytree=1.0, reg_alpha=0.0, reg_lambda=0.0, random_state=args.seed, n_jobs=-1, silent="warn", importance_type="split" ) # use gridsearch without fit function train_val_features = np.concatenate((train_dataset.get_data(), dev_dataset.get_data()), axis=0) train_val_labels = np.concatenate((train_dataset.get_label(), dev_dataset.get_label()), axis=0) dev_fold = np.zeros(train_val_features.shape[0]) dev_fold[:train_dataset.get_data().shape[0]] = -1 ps = PredefinedSplit(test_fold=dev_fold) # gsearch = GridSearchCV(gbm, param_grid=parameters, scoring=feval, cv=ps) gsearch = RandomizedSearchCV(gbm, param_distributions=parameters, scoring=feval, cv=ps) gsearch.fit(train_val_features, train_val_labels, eval_set=[(dev_dataset.get_data(), dev_dataset.get_label())], eval_metric=['auc', 'binary_logloss'], early_stopping_rounds=args.early_stopping_rounds if args.early_stopping_rounds > 0 else None, verbose=1) logger.info("Best score[%s]: %0.6f" % (args.max_metric_type, gsearch.best_score_)) logger.info("Best parameters set:") log_fp.write("Best score[%s]: %0.6f\n" % (args.max_metric_type, gsearch.best_score_)) log_fp.write("Best parameters set:\n") best_parameters = gsearch.best_estimator_.get_params() for param_name in sorted(parameters.keys()): logger.info("\t%s: %r" % (param_name, best_parameters[param_name])) log_fp.write("%s: %r\n" % (param_name, best_parameters[param_name])) log_fp.write("#" * 50 + "\n") global_step = best_parameters["n_estimators"] prefix = "checkpoint-{}".format(global_step) checkpoint = os.path.join(args.output_dir, prefix) if not os.path.exists(checkpoint): os.makedirs(checkpoint) # save gridsearch joblib.dump(gsearch, os.path.join(args.output_dir, "lgbm_gridsearch.pkl")) log_fp.write(str(gsearch.best_estimator_) + "\n" + "#" * 50 + "\n") tr_loss = 0 max_metric_model_info = {"global_step": global_step} save_model(gsearch.best_estimator_, os.path.join(checkpoint, "lgbm_model.txt")) json.dump(best_parameters, open(os.path.join(checkpoint, "config.json"), "w")) return global_step, tr_loss, max_metric_model_info def trainer(args, train_dataset, dev_dataset, log_fp=None): if log_fp is None: log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w") feature_size = train_dataset.num_feature() feature_names = ["f_%d" % idx for idx in range(feature_size)] train_num = train_dataset.num_data() if args.max_metric_type == "acc": feval = acc elif args.max_metric_type == "prec": feval = precision elif args.max_metric_type == "recall": feval = recall elif args.max_metric_type == "f1": feval = f1 elif args.max_metric_type == "pr_auc": feval = pr_auc elif args.max_metric_type == "roc_auc": feval = roc_auc else: feval = f1 param = get_model_param(args) log_fp.write(json.dumps(param, ensure_ascii=False) + "\n") log_fp.write("#" * 50 + "\n") log_fp.flush() evals_result = {} lgbm_model = lgb.train( params=param, train_set=train_dataset, num_boost_round=args.num_train_epochs, valid_sets=[train_dataset, dev_dataset], valid_names=["train", "dev"], callbacks=[lgb.record_evaluation(evals_result)], feval=feval, feature_name=feature_names, early_stopping_rounds=args.early_stopping_rounds if args.early_stopping_rounds > 0 else None, verbose_eval=1, keep_training_booster=True ) feature_importance_info = pd.DataFrame() feature_importance_info["feature"] = feature_names feature_importance_info["importance"] = lgbm_model.feature_importance() log_fp.write("best iteration: %d\n" % lgbm_model.best_iteration) log_fp.write("num trees: %d\n" % lgbm_model.num_trees()) global_step = lgbm_model.best_iteration prefix = "checkpoint-{}".format(global_step) checkpoint = os.path.join(args.output_dir, prefix) if not os.path.exists(checkpoint): os.makedirs(checkpoint) log_fp.write(str(evals_result) + "\n" + "#" * 50 + "\n") tr_loss = evals_result["train"]["binary_logloss"][global_step - 1] max_metric_model_info = {"global_step": global_step} save_model(lgbm_model, os.path.join(checkpoint, "lgbm_model.txt")) param["best_iteration"] = lgbm_model.best_iteration json.dump(param, open(os.path.join(checkpoint, "config.json"), "w")) return global_step, tr_loss, max_metric_model_info def evaluate(args, model, dataset, prefix, log_fp=None): X, y = dataset.get_data(), dataset.get_label() dev_num = X.shape[0] if args.grid_search: hat_y = model.predict_proba(X) else: hat_y = model.predict(X) if args.output_mode in ["multi-label", "multi_label"]: result = metrics_multi_label(y, hat_y, threshold=0.5) elif args.output_mode in ["multi-class", "multi_class"]: result = metrics_multi_class(y, hat_y) elif args.output_mode == "regression": pass # to do elif args.output_mode in ["binary-class", "binary_class"]: result = metrics_binary(y, hat_y, threshold=0.5, savepath=os.path.join(args.output_dir, "dev_confusion_matrix.png")) with open(os.path.join(args.output_dir, "dev_metrics.txt"), "w") as writer: logger.info("***** Eval Dev results {} *****".format(prefix)) for key in sorted(result.keys()): logger.info("%s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) logger.info("Dev metrics: ") logger.info(json.dumps(result, ensure_ascii=False)) logger.info("") return result def predict(args, model, dataset, prefix, log_fp=None): X, y = dataset.get_data(), dataset.get_label() test_num = X.shape[0] if args.grid_search: hat_y = model.predict_proba(X) else: hat_y = model.predict(X) if args.output_mode in ["multi_class", "multi-class"]: result = metrics_multi_class(y, hat_y) elif args.output_mode in ["multi_label", "multi-label"]: result = metrics_multi_label(y, hat_y, threshold=0.5) elif args.output_mode == "regression": pass # to do elif args.output_mode in ["binary_class", "binary-class"]: result = metrics_binary(y, hat_y, threshold=0.5, savepath=os.path.join(args.output_dir, "test_confusion_matrix.png")) with open(os.path.join(args.output_dir, "test_results.txt"), "w") as wfp: for idx in range(len(y)): wfp.write("%d,%s,%s\n" %(idx, str(hat_y[idx]), str(y[idx]))) with open(os.path.join(args.output_dir, "test_metrics.txt"), "w") as wfp: logger.info("***** Eval Test results {} *****".format(prefix)) for key in sorted(result.keys()): logger.info("%s = %s", key, str(result[key])) wfp.write("%s = %s\n" % (key, str(result[key]))) logger.info("Test metrics: ") logger.info(json.dumps(result, ensure_ascii=False)) logger.info("") return result if __name__ == "__main__": run()