def main()

in src/result_process/process_predict_result.py [0:0]


def main(input_fasta_filepaths, result_filedir, result_filenames, merge_dirname):
    """Merge prediction result CSVs and report coverage against the input FASTA.

    The function first scans the input FASTA file(s) to collect record ids,
    then reads every prediction CSV, writes merged outputs, and prints a
    summary of ids that were never predicted or predicted more than once.

    Outputs written to ``os.path.join(result_filedir, merge_dirname)``:

    * ``predict_all.csv`` - ``id, prob, pred`` for every prediction row.
    * ``predict_positive.csv`` - the same columns, rows with ``pred == 1`` only.
    * ``predict_positive.fasta`` - sequences of the positive rows, written
      through ``write_fasta``.

    :param input_fasta_filepaths: one FASTA path (str) or an iterable of paths;
        each is read with ``fasta_reader`` and the first element of every
        yielded row is used as the record id.
    :param result_filedir: directory that contains the prediction CSV files.
    :param result_filenames: iterable of CSV file names inside
        ``result_filedir``. The first row of each file is skipped as a header
        and every data row must have exactly 11 columns.
    :param merge_dirname: name of the sub-directory (created if missing) that
        receives the merged outputs.
    :raises ValueError: if a non-empty data row does not have 11 columns or
        its predicted label is not an integer.
    """
    if isinstance(input_fasta_filepaths, str):
        input_fasta_filepaths = [input_fasta_filepaths]

    # Pass 1: collect the ids present in the input FASTA file(s).
    fasta_total = 0
    repeat_cnt = 0
    fasta_id_set = set()
    for input_fasta_filepath in input_fasta_filepaths:
        for row in fasta_reader(input_fasta_filepath):
            fasta_total += 1
            if row[0] in fasta_id_set:
                repeat_cnt += 1
            fasta_id_set.add(row[0])
    print("total: %d, protein id: %d, repeat cnt: %d" % (fasta_total, len(fasta_id_set), repeat_cnt))
    print("fasta num: %d" % fasta_total)

    writer_dir = os.path.join(result_filedir, merge_dirname)
    os.makedirs(writer_dir, exist_ok=True)

    sequence_list = []
    predict_total = 0
    pred_fasta_id = {}

    # Pass 2: merge the prediction files. The context manager guarantees both
    # outputs are flushed and closed even if a row fails to parse;
    # newline="" is what the csv module requires for files it writes to.
    all_csv_path = os.path.join(writer_dir, "predict_all.csv")
    positive_csv_path = os.path.join(writer_dir, "predict_positive.csv")
    with open(all_csv_path, "w", newline="") as all_result_wfp, \
            open(positive_csv_path, "w", newline="") as positive_result_wfp:
        all_result_writer = csv.writer(all_result_wfp)
        all_result_writer.writerow(["id", "prob", "pred"])
        positive_result_writer = csv.writer(positive_result_wfp)
        positive_result_writer.writerow(["id", "prob", "pred"])

        for filename in result_filenames:
            filepath = os.path.join(result_filedir, filename)
            with open(filepath, "r", newline="") as rfp:
                reader = csv.reader(rfp)
                # Skip the header row; an empty file simply yields no rows.
                next(reader, None)
                for row in reader:
                    if not row:
                        # Blank line in the CSV: nothing to merge.
                        continue
                    (protein_id, seq, predict_prob, predict_label, seq_len,
                     pdb_filename, ptm, mean_plddt, emb_filename, label,
                     source) = row
                    predict_total += 1
                    pred_fasta_id.setdefault(protein_id, []).append(
                        (predict_prob, predict_label))
                    # discard(), not remove(): an id predicted more than once
                    # (or missing from the input FASTA) must not abort the
                    # merge; duplicates are reported after the loop.
                    # The raw id, before any ">" stripping, is matched here.
                    fasta_id_set.discard(protein_id)

                    if protein_id and protein_id[0] == ">":
                        protein_id = protein_id[1:]
                    if int(predict_label) == 1:
                        sequence_list.append(SeqRecord(
                            Seq(seq, None),
                            id=protein_id,
                            description=""))
                        positive_result_writer.writerow([protein_id, predict_prob, predict_label])
                    all_result_writer.writerow([protein_id, predict_prob, predict_label])

    print("predict num: %d" % predict_total)
    if len(fasta_id_set) > 0:
        print("not done set:")
        print(fasta_id_set)
    repeated_predictions = [item for item in pred_fasta_id.items() if len(item[1]) > 1]
    if len(repeated_predictions) > 0:
        print("done >= 2 times set:")
        print(repeated_predictions)
    # Guard the ratio so an empty prediction set reports 0 instead of raising
    # ZeroDivisionError.
    positive_rate = len(sequence_list) / predict_total if predict_total else 0.0
    print("total: %d, positive: %d, p rate: %f" % (predict_total, len(sequence_list), positive_rate))
    write_fasta(os.path.join(writer_dir, "predict_positive.fasta"), sequence_list)