src/result_process/process_predict_result.py (162 lines of code) (raw):

#!/usr/bin/env python # encoding: utf-8 ''' *Copyright (c) 2023, Alibaba Group; *Licensed under the Apache License, Version 2.0 (the "License"); *you may not use this file except in compliance with the License. *You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 *Unless required by applicable law or agreed to in writing, software *distributed under the License is distributed on an "AS IS" BASIS, *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *See the License for the specific language governing permissions and *limitations under the License. @author: Hey @email: sanyuan.**@**.com @tel: 137****6540 @datetime: 2023/1/18 14:03 @project: DeepProtFunc @file: process_predict_result.py @desc: process the results of prediction ''' import os, csv import io, textwrap, itertools from Bio import SeqIO from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord def fasta_reader(handle, width=None): """ Reads a FASTA file, yielding header, sequence pairs for each sequence recovered args: :handle (str, pathliob.Path, or file pointer) - fasta to read from :width (int or None) - formats the sequence to have max `width` character per line. If <= 0, processed as None. If None, there is no max width. yields: :(header, sequence) tuples returns: :None """ FASTA_STOP_CODON = "*" handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r') width = width if isinstance(width, int) and width > 0 else None try: for is_header, group in itertools.groupby(handle, lambda line: line.startswith(">")): if is_header: header = group.__next__().strip() else: seq = ''.join(line.strip() for line in group).strip().rstrip(FASTA_STOP_CODON) if width is not None: seq = textwrap.fill(seq, width) yield header, seq finally: if not handle.closed: handle.close() def write_fasta(filepath, sequences): ''' write fasta file :param filepath: the save filepath :param sequences: fasta sequence list (each item: [id, seq]) :return: ''' with open(filepath, "w") as output_handle: for sequence in sequences: SeqIO.write(sequence, output_handle, "fasta") def main(input_fasta_filepaths, result_filedir, result_filenames, merge_dirname): if isinstance(input_fasta_filepaths, str): input_fasta_filepaths = [input_fasta_filepaths] total = 0 fasta_id_set = set() repeat_cnt = 0 for input_fasta_filepath in input_fasta_filepaths: cnt = 0 for row in fasta_reader(input_fasta_filepath): cnt += 1 if row[0] in fasta_id_set: repeat_cnt += 1 fasta_id_set.add(row[0]) total += cnt print("total: %d, protein id: %d, repeat cnt: %d" % (total, len(fasta_id_set), repeat_cnt)) # assert len(fasta_id_set) == total print("fasta num: %d" % total) writer_dir = os.path.join(result_filedir, merge_dirname) if not os.path.exists(writer_dir): os.makedirs(writer_dir) all_result_wfp = open(os.path.join(writer_dir, "predict_all.csv"), "w") all_result_wtiter = csv.writer(all_result_wfp) all_result_wtiter.writerow(["id", "prob", "pred"]) positive_result_wfp = open(os.path.join(writer_dir, "predict_positive.csv"), "w") positive_result_writer = csv.writer(positive_result_wfp) positive_result_writer.writerow(["id", "prob", "pred"]) sequence_list = [] total = 0 pred_fasta_id = {} for filename in result_filenames: filepath = os.path.join(result_filedir, filename) with open(filepath, "r") as rfp: reader = csv.reader(rfp) cnt = 0 for row in reader: cnt += 1 if cnt == 1: continue protein_id, seq, predict_prob, predict_label, seq_len, pdb_filename, ptm, mean_plddt, emb_filename, label, source = row if protein_id in pred_fasta_id: pred_fasta_id[protein_id].append((predict_prob, predict_label)) else: pred_fasta_id[protein_id] = [(predict_prob, predict_label)] fasta_id_set.remove(protein_id) protein_id = protein_id[1:] if protein_id and protein_id[0] == ">" else protein_id if int(predict_label) == 1: sequence_list.append(SeqRecord( Seq(seq, None), id=protein_id, description="")) positive_result_writer.writerow([protein_id, predict_prob, predict_label]) all_result_wtiter.writerow([protein_id, predict_prob, predict_label]) total += cnt - 1 print("predict num: %d" % total) if len(fasta_id_set) > 0: print("not done set:") print(fasta_id_set) pred_fasta_id = [item for item in pred_fasta_id.items() if len(item[1]) > 1] if len(pred_fasta_id) > 0: print("done >= 2 times set:") print(pred_fasta_id) print("total: %d, positive: %d, p rate: %f" % (total, len(sequence_list), len(sequence_list)/total)) write_fasta(os.path.join(writer_dir, "predict_positive.fasta"), sequence_list) all_result_wfp.close() positive_result_wfp.close() if __name__ == "__main__": ''' input_fasta_filepath = "/mnt/****/biodata/20230108-to-Ali/00self_sequecing_500aa.pep" result_filedir = "../predicts/rdrp_40/protein/binary_class/sefn/20230107005818/" result_filenames = ["00self_sequecing_500aa_001_with_pdb_emb/pred_result.csv", "00self_sequecing_500aa_002_with_pdb_emb/pred_result.csv", "00self_sequecing_500aa_003_with_pdb_emb/pred_result.csv"] merge_dirname = "00self_sequecing_500aa" main(input_fasta_filepath, result_filedir, result_filenames, merge_dirname) ''' input_fasta_filepath = ["/mnt2/****/biodata/20221123-to-Ali/all_500aa.pep.split/all_500aa.part_001.pep", "/mnt2/****/biodata/20221123-to-Ali/all_500aa.pep.split/all_500aa.part_002.pep", "/mnt2/****/biodata/20221123-to-Ali/all_500aa.pep.split/all_500aa.part_003.pep", "/mnt2/****/biodata/20221123-to-Ali/all_500aa.pep.split/all_500aa.part_004.pep", "/mnt2/****/biodata/20221123-to-Ali/all_500aa.pep.split/all_500aa.part_005.pep", "/mnt2/****/biodata/20221123-to-Ali/all_500aa.pep.split/all_500aa.part_006.pep"] result_filedir = "../predicts/rdrp_40/protein/binary_class/sefn/20230107005818/checkpoint-95000/" result_filenames = ["all_500aa.part_003_with_pdb_emb/pred_result.csv", "all_500aa.part_004_with_pdb_emb/pred_result.csv", "all_500aa.part_005_with_pdb_emb/pred_result.csv", "all_500aa.part_006_with_pdb_emb/pred_result.csv", "all_500aa.part_001_001_with_pdb_emb/pred_result.csv", "all_500aa.part_001_002_with_pdb_emb/pred_result.csv", "all_500aa.part_001_003_with_pdb_emb/pred_result.csv", "all_500aa.part_001_004_with_pdb_emb/pred_result.csv", "all_500aa.part_002_001_with_pdb_emb/pred_result.csv", "all_500aa.part_002_002_with_pdb_emb/pred_result.csv", "all_500aa.part_002_003_with_pdb_emb/pred_result.csv", "all_500aa.part_002_004_with_pdb_emb/pred_result.csv", "all_500aa.part_003_001_with_pdb_emb/pred_result.csv", "all_500aa.part_003_002_with_pdb_emb/pred_result.csv", "all_500aa.part_003_003_with_pdb_emb/pred_result.csv", "all_500aa.part_003_004_with_pdb_emb/pred_result.csv", "all_500aa.part_004_001_with_pdb_emb/pred_result.csv", "all_500aa.part_004_002_with_pdb_emb/pred_result.csv", "all_500aa.part_004_003_with_pdb_emb/pred_result.csv", "all_500aa.part_004_004_with_pdb_emb/pred_result.csv", "all_500aa.part_005_001_with_pdb_emb/pred_result.csv", "all_500aa.part_005_002_with_pdb_emb/pred_result.csv", "all_500aa.part_005_003_with_pdb_emb/pred_result.csv", "all_500aa.part_005_004_with_pdb_emb/pred_result.csv", "all_500aa.part_006_001_with_pdb_emb/pred_result.csv", "all_500aa.part_006_002_with_pdb_emb/pred_result.csv", "all_500aa.part_006_003_with_pdb_emb/pred_result.csv", "all_500aa.part_006_004_with_pdb_emb/pred_result.csv" ] merge_dirname = "all_500aa.pep.split" main(input_fasta_filepath, result_filedir, result_filenames, merge_dirname)