in src/result_process/process_predict_result.py [0:0]
def main(input_fasta_filepaths, result_filedir, result_filenames, merge_dirname):
    """Merge per-file prediction CSVs and reconcile them against the input FASTA ids.

    The function first collects every id (``row[0]``) yielded by
    ``fasta_reader`` for the given FASTA file(s), then reads each prediction
    CSV and writes three files into ``result_filedir/merge_dirname``:

    * ``predict_all.csv``: id, prob, pred for every prediction row;
    * ``predict_positive.csv``: the same columns for rows whose label is 1;
    * ``predict_positive.fasta``: the sequences of the positive rows.

    Along the way it prints diagnostics: the number of FASTA records and
    duplicated ids, ids that were never predicted, ids that were predicted but
    are absent from the FASTA input, and ids predicted more than once.

    Args:
        input_fasta_filepaths: A single FASTA path or an iterable of paths.
        result_filedir: Directory containing the prediction CSV files; the
            merged output directory is created beneath it.
        result_filenames: Names of the prediction CSV files inside
            ``result_filedir``. Each file must start with a header row and
            contain exactly 11 columns per data row.
        merge_dirname: Name of the output sub-directory.

    Raises:
        ValueError: If a data row does not have exactly 11 columns or its
            predicted label cannot be parsed as an integer.
    """
    if isinstance(input_fasta_filepaths, str):
        input_fasta_filepaths = [input_fasta_filepaths]

    # Pass 1: gather the ids present in the input FASTA file(s).
    total = 0
    repeat_cnt = 0
    fasta_id_set = set()
    for input_fasta_filepath in input_fasta_filepaths:
        for row in fasta_reader(input_fasta_filepath):
            total += 1
            if row[0] in fasta_id_set:
                repeat_cnt += 1
            fasta_id_set.add(row[0])
    # Duplicated ids are reported rather than treated as a fatal error.
    print("total: %d, protein id: %d, repeat cnt: %d" % (total, len(fasta_id_set), repeat_cnt))
    print("fasta num: %d" % total)

    writer_dir = os.path.join(result_filedir, merge_dirname)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(writer_dir, exist_ok=True)

    # Pass 2: merge the prediction files.
    sequence_list = []
    total = 0
    pred_fasta_id = {}
    unknown_ids = set()
    all_result_path = os.path.join(writer_dir, "predict_all.csv")
    positive_result_path = os.path.join(writer_dir, "predict_positive.csv")
    # The context manager guarantees both outputs are flushed and closed even
    # if a malformed row aborts the merge; newline="" is what the csv module
    # requires to avoid doubled line endings on some platforms.
    with open(all_result_path, "w", newline="") as all_result_wfp, \
            open(positive_result_path, "w", newline="") as positive_result_wfp:
        all_result_writer = csv.writer(all_result_wfp)
        all_result_writer.writerow(["id", "prob", "pred"])
        positive_result_writer = csv.writer(positive_result_wfp)
        positive_result_writer.writerow(["id", "prob", "pred"])

        for filename in result_filenames:
            filepath = os.path.join(result_filedir, filename)
            with open(filepath, "r", newline="") as rfp:
                reader = csv.reader(rfp)
                # Skip the header; the default keeps an empty file from
                # raising (and from being counted as -1 rows).
                next(reader, None)
                for row in reader:
                    total += 1
                    (protein_id, seq, predict_prob, predict_label, seq_len,
                     pdb_filename, ptm, mean_plddt, emb_filename, label, source) = row
                    if protein_id in pred_fasta_id:
                        pred_fasta_id[protein_id].append((predict_prob, predict_label))
                    else:
                        pred_fasta_id[protein_id] = [(predict_prob, predict_label)]
                        if protein_id in fasta_id_set:
                            # What is left in the set afterwards is "not done".
                            fasta_id_set.remove(protein_id)
                        else:
                            # Predicted id with no matching FASTA record:
                            # report it instead of crashing with KeyError.
                            unknown_ids.add(protein_id)
                    # Drop the FASTA header marker before writing the id out.
                    if protein_id and protein_id[0] == ">":
                        protein_id = protein_id[1:]
                    if int(predict_label) == 1:
                        sequence_list.append(SeqRecord(
                            Seq(seq, None),
                            id=protein_id,
                            description=""))
                        positive_result_writer.writerow([protein_id, predict_prob, predict_label])
                    all_result_writer.writerow([protein_id, predict_prob, predict_label])

    print("predict num: %d" % total)
    if fasta_id_set:
        print("not done set:")
        print(fasta_id_set)
    if unknown_ids:
        print("not in fasta set:")
        print(unknown_ids)
    repeated_predictions = [item for item in pred_fasta_id.items() if len(item[1]) > 1]
    if repeated_predictions:
        print("done >= 2 times set:")
        print(repeated_predictions)
    # Guard against ZeroDivisionError when no prediction rows were read.
    positive_rate = len(sequence_list) / total if total else 0.0
    print("total: %d, positive: %d, p rate: %f" % (total, len(sequence_list), positive_rate))
    write_fasta(os.path.join(writer_dir, "predict_positive.fasta"), sequence_list)