src/protein_structure/merge_embedding_pdb_result.py (177 lines of code) (raw):
#!/usr/bin/env python
# encoding: utf-8
'''
*Copyright (c) 2023, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
@author: Hey
@email: sanyuan.**@**.com
@tel: 137****6540
@datetime: 2022/12/27 15:58
@project: DeepProtFunc
@file: merge_embedding_pdb_result.py
@desc: merge protein sequence、pdb filepath、embedding filepath info
'''
import os, csv, sys
sys.path.append(".")
sys.path.append("..")
sys.path.append("../..")
sys.path.append("../../src")
try:
from utils import write_fasta, fasta_reader, csv_reader, txt_reader
except ImportError:
from src.utils import write_fasta, fasta_reader, csv_reader, txt_reader
def load_file(filepath: str):
if filepath.endswith(".csv"):
reader = csv_reader(filepath, header_filter=True, header=True)
elif filepath.endswith(".txt"):
reader = txt_reader(filepath, header_filter=False, header=False)
else:
reader = fasta_reader(filepath)
# protein list
prot_id_list = set()
prot_id_2_seq = {}
for row in reader:
protein_id = row[0].strip()
seq = row[1].strip()
prot_id_list.add(protein_id)
prot_id_2_seq[protein_id] = seq
return prot_id_list, prot_id_2_seq
def load_pdb(protein_id_2_idx_fileptah, pdb_dirpath):
'''
load all PDB files
:param protein_id_2_idx_fileptah: mapping betwwen protein id and index filepath
:param pdb_dirpath: pdb saved dir
:return:
'''
if not os.path.exists(pdb_dirpath):
print("pdb dir: %s not exists!" % pdb_dirpath)
return None
protein_2_pdb_filepath = os.path.join(os.path.dirname(os.path.dirname(protein_id_2_idx_fileptah)),
".".join(os.path.basename(protein_id_2_idx_fileptah).split(".")[0:-1]) + "_protein_2_pdb.csv")
if os.path.exists(protein_2_pdb_filepath):
raise Exception("file: %s exists!" % protein_2_pdb_filepath)
prot_id_list = set()
done_prot_id_list = set()
with open(protein_2_pdb_filepath, "w") as wfp:
writer = csv.writer(wfp)
writer.writerow(["prot_id", "pdb_filename", "ptm", "mean_plddt"])
with open(protein_id_2_idx_fileptah, "r") as rfp:
reader = csv.reader(rfp)
cnt = 0
for row in reader:
cnt += 1
if cnt == 1:
continue
index, prot_id, seq_len, ptm, mean_plddt = row
prot_id_list.add(prot_id)
pdb_path = os.path.join(pdb_dirpath, "protein_%s.pdb" % index)
if os.path.exists(pdb_path):
writer.writerow([prot_id, "protein_%s.pdb" % index, ptm, mean_plddt])
done_prot_id_list.add(prot_id)
else:
writer.writerow([prot_id, None, ptm, mean_plddt])
print("pdb want to do: %d, done: %d, undo: %d" % (len(prot_id_list), len(done_prot_id_list), len(prot_id_list.difference(done_prot_id_list))))
return protein_2_pdb_filepath
def load_emb(protein_id_2_idx_filepath, embedding_dirpath):
if not os.path.exists(embedding_dirpath):
raise Exception("emb dir: %s not exists!" % embedding_dirpath)
protein_2_embedding_filepath = os.path.join(os.path.dirname(protein_id_2_idx_filepath[-1]),
".".join(os.path.basename(protein_id_2_idx_filepath[-1]).split(".")[0:-1]) + "_protein_2_emb.csv")
if os.path.exists(protein_2_embedding_filepath):
raise Exception("file: %s exists!" % protein_2_embedding_filepath)
prot_id_list = set()
done_prot_id_list = set()
with open(protein_2_embedding_filepath, "w") as wfp:
writer = csv.writer(wfp)
writer.writerow(["prot_id", "emb_filename"])
for cur_protein_id_2_idx_filepath in protein_id_2_idx_filepath:
with open(cur_protein_id_2_idx_filepath, "r") as rfp:
reader = csv.reader(rfp)
cnt = 0
for row in reader:
cnt += 1
if cnt == 1:
continue
index, prot_id = row
prot_id_list.add(prot_id)
emb_path = os.path.join(embedding_dirpath, "%s.pt" % index)
if os.path.exists(emb_path):
writer.writerow([prot_id, "%s.pt" % index])
done_prot_id_list.add(prot_id)
else:
raise Exception("emb_path :%s not exists" % emb_path)
if (cnt - 1) % 10000 == 0:
print("done %d" % (cnt - 1))
print("embedding want to do: %d, done: %d, undo: %d" % (len(prot_id_list), len(done_prot_id_list), len(prot_id_list.difference(done_prot_id_list))))
return protein_2_embedding_filepath
def merge(fasta_filepath, protein_2_pdb_filepath, protein_2_emb_filepath, label, source):
structure = {}
if protein_2_pdb_filepath and os.path.exists(protein_2_pdb_filepath):
with open(protein_2_pdb_filepath, "r") as rfp:
reader = csv.reader(rfp)
cnt = 0
for row in reader:
cnt += 1
if cnt == 1:
continue
prot_id, pdb_filename, ptm, mean_plddt = row
structure[prot_id] = [pdb_filename, ptm, mean_plddt]
embedding = {}
with open(protein_2_emb_filepath, "r") as rfp:
reader = csv.reader(rfp)
cnt = 0
for row in reader:
cnt += 1
if cnt == 1:
continue
prot_id, emb_filename = row
embedding[prot_id] = emb_filename
prot_id_list, prot_id_2_seq = load_file(fasta_filepath)
save_filepath = os.path.join(os.path.dirname(fasta_filepath), ".".join(os.path.basename(fasta_filepath).split(".")[:-1]) + "_with_pdb_emb.csv")
print("save path: %s" % save_filepath)
if os.path.exists(save_filepath):
raise Exception("file: %s exists!" % save_filepath)
with open(save_filepath, "w") as wfp:
writer = csv.writer(wfp)
writer.writerow(["prot_id", "seq", "seq_len", "pdb_filename", "ptm", "mean_plddt", "emb_filename", "label", "source"])
stats = {"seq": 0, "pdb": 0, "emb": 0}
for prot_id in prot_id_list:
seq = prot_id_2_seq[prot_id].strip().strip("*")
seq_len = len(seq)
pdb_filename, ptm, mean_plddt = None, None, None
emb_filename = None
stats["seq"] += 1
if prot_id in structure:
pdb_filename, ptm, mean_plddt = structure[prot_id]
stats["pdb"] += 1
if prot_id in embedding:
emb_filename = embedding[prot_id]
stats["emb"] += 1
writer.writerow([prot_id, seq, seq_len, pdb_filename, ptm, mean_plddt, emb_filename, label, source])
print("stats: ")
print(stats)
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--fasta_filepath", default=None, required=True, type=str, help="fasta filepath")
parser.add_argument("--protein_id_2_pdb_idx_filepath", default=None, type=str, help="the protein id to pdb file index name filepath")
parser.add_argument("--pdb_dirpath", default=None, type=str, help="pdb file dirpath(every filename is index)")
parser.add_argument("--protein_id_2_emb_idx_filepath", default=None, required=True, type=str, help="the protein id to embedding file index name filepath")
parser.add_argument("--emb_dirpath", default=None, type=str, help="embedding file dirpath(every filename is index)")
parser.add_argument("--label", default=None, type=int, help="this dataset label")
parser.add_argument("--source", default=None, type=str, help="")
args = parser.parse_args()
if __name__ == "__main__":
fasta_filepath = args.fasta_filepath
protein_id_2_pdb_idx_filepath = args.protein_id_2_pdb_idx_filepath
pdb_dirpath = args.pdb_dirpath
protein_id_2_emb_idx_filepath = args.protein_id_2_emb_idx_filepath
embedding_dirpath = args.embedding_dirpath
protein_2_pdb_filepath = load_pdb(protein_id_2_pdb_idx_filepath, pdb_dirpath)
protein_2_emb_filepath = load_emb(protein_id_2_emb_idx_filepath, embedding_dirpath)
merge(fasta_filepath, protein_2_pdb_filepath, protein_2_emb_filepath, label=args.label, source=args.source)
merge(fasta_filepath, protein_2_pdb_filepath, protein_2_emb_filepath, label=args.labe, source=args.source)