def run()

in src/baselines/predict.py [0:0]


def run(args):
    '''
    Predict over a dataset chunk by chunk and write the results to CSV files.

    The number of chunks is derived from ``calc_dataset_size(args.data_path)``
    and ``args.split_size`` (ceiling division); one chunk is pulled from
    ``load_dataset`` per iteration and passed to ``predict``.

    Every prediction is written to ``<args.save_dir>/pred_result.csv`` with the
    columns ``protein_id, predict_prob, predict_label``. For binary
    classification, rows predicted as class 1 are additionally written to
    ``<args.save_dir>/pred_result_positive.csv``. Both files are flushed after
    each chunk and are closed even if an error occurs part-way through.

    :param args: namespace providing ``data_path``, ``split_size``,
        ``model_path``, ``model_type``, ``label_filepath``, ``save_dir``,
        ``task_type`` and (for multi-label / binary tasks) ``threshold``
    :return: None
    :raises Exception: if ``args.task_type`` is not one of the supported task
        types (raised while processing the first chunk)
    '''
    # Local import: keeps this function self-contained with respect to the
    # module's import block.
    from contextlib import ExitStack

    is_binary = args.task_type in ("binary-class", "binary_class")
    header = ["protein_id", "predict_prob", "predict_label"]

    # load data
    total_records_size = calc_dataset_size(args.data_path)
    # ceiling division: a trailing partial chunk still counts as one chunk
    dataset_num = (total_records_size + (args.split_size - 1)) // args.split_size
    dataset_iterator = trange(int(dataset_num), desc="Dataset Iterator", disable=False)
    dataset_gen = load_dataset(args.data_path, split_size=args.split_size)
    # load model
    model = load_model(args.model_path, args.model_type)
    # load labels
    label_list = load_labels(args.label_filepath, header=True)
    label_map = dict(enumerate(label_list))

    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(args.save_dir, exist_ok=True)

    # ExitStack closes every opened file on any exit path, including errors
    # raised by prediction or by the unsupported-task-type check below.
    with ExitStack() as stack:
        # newline="" is what the csv module expects; otherwise platforms that
        # translate line endings emit an extra "\r" per row.
        wfp = stack.enter_context(
            open(os.path.join(args.save_dir, "pred_result.csv"), "w", newline="")
        )
        writer = csv.writer(wfp)
        writer.writerow(header)

        # If it is a binary classification, write a separate file for the
        # samples predicted to be positive
        wfp_positive = None
        writer_positive = None
        if is_binary:
            wfp_positive = stack.enter_context(
                open(os.path.join(args.save_dir, "pred_result_positive.csv"), "w", newline="")
            )
            writer_positive = csv.writer(wfp_positive)
            writer_positive.writerow(header)

        for _ in dataset_iterator:
            id_list, X = next(dataset_gen)
            if args.model_type == "xgb":
                X = xgboost.DMatrix(X)
            probs = predict(args, model, X)

            if args.task_type in ("multi-label", "multi_label"):
                preds = relevant_indexes((probs >= args.threshold).astype(int))
                for idx, protein_id in enumerate(id_list):
                    label_names = [label_map[label_idx] for label_idx in preds[idx]]
                    writer.writerow([protein_id, probs[idx], label_names])
            elif args.task_type in ("multi-class", "multi_class"):
                preds = np.argmax(probs, axis=1)
                for idx, protein_id in enumerate(id_list):
                    writer.writerow([protein_id, probs[idx], label_map[preds[idx]]])
            elif args.task_type == "regression":
                # TODO: regression output is not implemented; no rows are written.
                pass
            elif is_binary:
                preds = (probs >= args.threshold).astype(int)
                for idx, protein_id in enumerate(id_list):
                    row = [protein_id, probs[idx], label_map[preds[idx]]]
                    writer.writerow(row)
                    if preds[idx] == 1:
                        writer_positive.writerow(row)
                wfp_positive.flush()
            else:
                raise Exception("not support task_type=%s" % args.task_type)

            # flush per chunk so partial results are on disk while the run continues
            wfp.flush()