in src/protein_structure/predict_structure.py [0:0]
def predict_embedding(sample, trunc_type, embedding_type, repr_layers=None, truncation_seq_length=4094, device=None):
    '''
    Predict the embedding matrix or the BOS vector of one protein sequence
    using the module-level ``model`` / ``alphabet`` (loaded lazily on first use).

    :param sample: [protein_id, protein_sequence]
    :param trunc_type: "left" keeps the last ``truncation_seq_length`` residues
        (truncates on the left); any other value keeps the first ones
    :param embedding_type: "bos" returns the vector at token position 0;
        "representations" or "matrix" returns the per-position rows that
        follow position 0
    :param repr_layers: layer indices to request from the model; negative
        values count from the end. Defaults to [-1] (the final layer)
    :param truncation_seq_length: maximum sequence length that is kept,
        e.g. one of [4094,2046,1982,1790,1534,1278,1150,1022]
    :param device: device to run on (``torch.device``, or anything accepted by
        ``torch.device(...)``); None keeps the model on its current device
    :return: (embedding, protein_seq) where ``embedding`` is a numpy array and
        ``protein_seq`` is the (possibly truncated) sequence that was embedded;
        (None, None) if the forward pass raised a RuntimeError on a non-CPU
        device, so that the caller can retry on CPU
    :raises ValueError: if ``repr_layers`` is empty
    :raises Exception: wrapping the RuntimeError when the forward pass fails
        on CPU
    '''
    global model, alphabet
    assert embedding_type in ["bos", "representations", "matrix"]

    # None sentinel instead of a shared mutable default list.
    repr_layers = [-1] if repr_layers is None else list(repr_layers)
    if not repr_layers:
        raise ValueError("repr_layers must contain at least one layer index")

    protein_id, protein_seq = sample[0], sample[1]
    seq_len = len(protein_seq)
    if seq_len > truncation_seq_length:
        if trunc_type == "left":
            # Explicit start index: `seq[-0:]` would return the whole sequence
            # when truncation_seq_length is 0.
            protein_seq = protein_seq[seq_len - truncation_seq_length:]
        else:
            protein_seq = protein_seq[:truncation_seq_length]

    if model is None or alphabet is None:
        model, alphabet = pretrained.load_model_and_alphabet("esm2_t36_3B_UR50D")

    num_layers = model.num_layers
    assert all(-(num_layers + 1) <= i <= num_layers for i in repr_layers)
    # Map negative indices onto the range 0..num_layers.
    repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]
    # Read the final layer when it was requested (the historical behaviour for
    # the 36-layer default model); otherwise read the last requested layer
    # instead of failing on a hard-coded layer number.
    out_layer = num_layers if num_layers in repr_layers else repr_layers[-1]

    model.eval()
    model_device = next(model.parameters()).device
    if device is None:
        device = model_device
    else:
        # Accept strings / indices too, so that the comparison below and the
        # `device.type` check in the error handler are well defined.
        if not isinstance(device, torch.device):
            device = torch.device(device)
        if device != model_device:
            model = model.to(device)

    converter = BatchConverter(alphabet, truncation_seq_length)
    _, raw_seqs, tokens = converter([[protein_id, protein_seq]])
    with torch.no_grad():
        tokens = tokens.to(device=device, non_blocking=True)
        try:
            out = model(tokens, repr_layers=repr_layers, return_contacts=False)
            truncate_len = min(truncation_seq_length, len(raw_seqs[0]))
            layer_repr = out["representations"][out_layer]
            # Slice before the device-to-host copy so only the needed rows
            # are transferred; clone() detaches the result from the full tensor.
            if embedding_type in ["representations", "matrix"]:
                selected = layer_repr[0, 1: truncate_len + 1]
            else:
                selected = layer_repr[0, 0]
            embedding = selected.to(device="cpu").clone().numpy()
            return embedding, protein_seq
        except RuntimeError as e:
            # str(e) is safe even when the exception carries no (string) args.
            if str(e).startswith("CUDA out of memory"):
                print(f"Failed (CUDA out of memory) on sequence {sample[0]} of length {len(sample[1])}.")
                print("Please reduce the 'truncation_seq_length'")
            if device.type == "cpu":
                # insufficient cpu memory: nothing to fall back to
                raise Exception(e) from e
            # failure in GPU, return None to continue using CPU
            return None, None