in src/baselines/predict.py [0:0]
def run(args):
    """Run chunked batch prediction over a dataset and write the results as CSV.

    The dataset at ``args.data_path`` is read in chunks of ``args.split_size``
    records. Each chunk is scored by the model loaded from ``args.model_path``.
    One row per record is written to ``<args.save_dir>/pred_result.csv`` with
    the columns ``protein_id, predict_prob, predict_label``.

    How ``predict_label`` is derived depends on ``args.task_type``:

    * multi-label: every label index whose probability is ``>= args.threshold``
      (selected via ``relevant_indexes``), written as a list of label names.
    * multi-class: the label whose probability is largest (``argmax`` on axis 1).
    * binary-class: label index 1 when ``probs >= args.threshold``, else 0.
      Rows predicted as index 1 are also written to
      ``<args.save_dir>/pred_result_positive.csv``.
    * regression: accepted, but no rows are written yet (TODO). Only the CSV
      header is produced.

    :param args: parsed arguments. This function reads ``data_path``,
        ``split_size``, ``model_path``, ``model_type``, ``label_filepath``,
        ``save_dir``, ``task_type`` and ``threshold``. ``args`` is also passed
        to ``predict``, which may read more attributes.
    :return: None
    :raises ValueError: if ``args.task_type`` is not a supported task type.
        This is checked before any data, model or output file is touched.
    """
    import contextlib

    multi_label_types = ("multi-label", "multi_label")
    multi_class_types = ("multi-class", "multi_class")
    binary_class_types = ("binary-class", "binary_class")
    regression_types = ("regression",)
    task_type = args.task_type
    # Fail fast. An unsupported task type used to be detected only after the
    # model was loaded, the output files were created and a chunk was predicted.
    if task_type not in multi_label_types + multi_class_types + binary_class_types + regression_types:
        raise ValueError("not support task_type=%s" % task_type)
    is_binary = task_type in binary_class_types

    # load data
    total_records_size = calc_dataset_size(args.data_path)
    # Ceiling division: number of chunks of at most split_size records each.
    dataset_num = (total_records_size + (args.split_size - 1)) // args.split_size
    dataset_iterator = trange(int(dataset_num), desc="Dataset Iterator", disable=False)
    dataset_gen = load_dataset(args.data_path, split_size=args.split_size)
    # load model
    model = load_model(args.model_path, args.model_type)
    # load labels: predicted label index -> label name
    label_list = load_labels(args.label_filepath, header=True)
    label_map = dict(enumerate(label_list))

    os.makedirs(args.save_dir, exist_ok=True)
    header = ["protein_id", "predict_prob", "predict_label"]
    # ExitStack guarantees every opened output file is closed, even when
    # loading or prediction raises midway through the dataset.
    with contextlib.ExitStack() as stack:
        # newline="" is required by the csv module for files it writes to.
        wfp = stack.enter_context(
            open(os.path.join(args.save_dir, "pred_result.csv"), "w", newline="")
        )
        writer = csv.writer(wfp)
        writer.writerow(header)
        # For binary classification, also write a separate file containing
        # only the samples predicted to be positive (label index 1).
        wfp_positive = None
        writer_positive = None
        if is_binary:
            wfp_positive = stack.enter_context(
                open(os.path.join(args.save_dir, "pred_result_positive.csv"), "w", newline="")
            )
            writer_positive = csv.writer(wfp_positive)
            writer_positive.writerow(header)

        for _ in dataset_iterator:
            id_list, X = next(dataset_gen)
            if args.model_type == "xgb":
                X = xgboost.DMatrix(X)
            probs = predict(args, model, X)
            # Rows are indexed by position, so a length mismatch between
            # id_list and probs/preds still raises instead of truncating.
            if task_type in multi_label_types:
                preds = relevant_indexes((probs >= args.threshold).astype(int))
                for idx, protein_id in enumerate(id_list):
                    writer.writerow([protein_id, probs[idx], [label_map[label_idx] for label_idx in preds[idx]]])
            elif task_type in multi_class_types:
                preds = np.argmax(probs, axis=1)
                for idx, protein_id in enumerate(id_list):
                    writer.writerow([protein_id, probs[idx], label_map[preds[idx]]])
            elif task_type in binary_class_types:
                # NOTE(review): label_map[preds[idx]] assumes probs is 1-D,
                # i.e. one probability per record. Confirm against predict().
                preds = (probs >= args.threshold).astype(int)
                for idx, protein_id in enumerate(id_list):
                    row = [protein_id, probs[idx], label_map[preds[idx]]]
                    writer.writerow(row)
                    if preds[idx] == 1:
                        writer_positive.writerow(row)
            else:
                # regression: output format not implemented yet (TODO)
                pass
            # Flush once per chunk so partial results are visible during long runs.
            wfp.flush()
            if wfp_positive is not None:
                wfp_positive.flush()