in src/protein_structure/embedding_from_esmfold.py [0:0]
import csv
import os

import torch
from esm import FastaBatchedDataset, MSATransformer, pretrained
from tqdm import tqdm


def main(args):
    model, alphabet = pretrained.load_model_and_alphabet(args.model_name)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    if torch.cuda.is_available() and not args.nogpu:
        model = model.cuda()
        # print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.file)
    # Alternative token-budget batching, disabled in favor of batch_size=1 below:
    # batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    # data_loader = torch.utils.data.DataLoader(
    #     dataset,
    #     collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
    #     batch_sampler=batches,
    # )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length)
    )
    print(f"Read {args.file} with {len(dataset)} sequences")
    os.makedirs(args.output_dir, exist_ok=True)

    return_contacts = "contacts" in args.include
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
    # Map negative layer indices (e.g. -1 for the last layer) onto absolute layer numbers.
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]
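    # For example, with a 33-layer model the mapping sends [-1, 0, 33] to [33, 0, 33],
    # so -1 selects the final layer's hidden states.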
    # Bookkeeping files; reopen in append mode when resuming a previous run.
    had = False
    if os.path.exists(args.uncompleted_file):
        uncompleted_wfp = open(args.uncompleted_file, "a+")
    else:
        uncompleted_wfp = open(args.uncompleted_file, "w")
    if os.path.exists(args.fasta_id_2_idx_file):
        fasta_id_2_idx_wfp = open(args.fasta_id_2_idx_file, "a+")
        had = True
    else:
        fasta_id_2_idx_wfp = open(args.fasta_id_2_idx_file, "w")
    fasta_id_2_idx_writer = csv.writer(fasta_id_2_idx_wfp)
    if not had:
        # Write the CSV header only for a freshly created mapping file.
        fasta_id_2_idx_writer.writerow(["index", "uuid"])
    protein_idx = args.begin_uuid_index
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Iteration")):
            # strs: the original sequences; toks: the tokenized tensor
            # (already truncated and padded by the batch converter).
            protein_ids, strs, toks = batch
            protein_ids = [">" + v.strip() if v and v[0] != ">" else v.strip() for v in protein_ids]
            if torch.cuda.is_available() and not args.nogpu:
                toks = toks.to(device="cuda", non_blocking=True)
            try:
                out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)
            except RuntimeError as e:
                if e.args[0].startswith("CUDA out of memory"):
                    if len(strs) > 1:
                        print(
                            f"Failed (CUDA out of memory) to predict batch of size {len(strs)}. "
                            "Try lowering `--toks_per_batch`."
                        )
                    else:
                        print(
                            f"Failed (CUDA out of memory) on sequence {protein_ids[0]} "
                            f"of length {len(strs[0])}."
                        )
                # Record the failed sequences so they can be retried later, then move on.
                for idx, v in enumerate(protein_ids):
                    uncompleted_wfp.write("%s,%d\n" % (v, len(strs[idx])))
                uncompleted_wfp.flush()
                continue
# logits = out["logits"].to(device="cpu")
representations = {
layer: t.to(device="cpu") for layer, t in out["representations"].items()
}
if return_contacts:
contacts = out["contacts"].to(device="cpu")
for idx, protein_id in enumerate(protein_ids):
protein_idx += 1
cur_output_file = os.path.join(args.output_dir, "%s.pt" % protein_idx)
result = {"protein_id": protein_id, "seq": strs[idx], "seq_len": len(strs[idx]), "max_len": args.truncation_seq_length}
truncate_len = min(args.truncation_seq_length, len(strs[idx]))
# Call clone on tensors to ensure tensors are not views into a larger representation
# See https://github.com/pytorch/pytorch/issues/1995
if "per_tok" in args.include:
result["representations"] = {
layer: t[idx, 1: truncate_len + 1].clone() for layer, t in representations.items()
}
if "mean" in args.include:
result["mean_representations"] = {
layer: t[idx, 1: truncate_len + 1].mean(0).clone() for layer, t in representations.items()
}
if "bos" in args.include:
result["bos_representations"] = {
layer: t[idx, 0].clone() for layer, t in representations.items()
}
if return_contacts:
result["contacts"] = contacts[idx, 1: truncate_len + 1, 1: truncate_len + 1].clone()
torch.save(
result,
cur_output_file,
)
fasta_id_2_idx_writer.writerow([protein_idx, protein_id])
fasta_id_2_idx_wfp.flush()
    uncompleted_wfp.close()
    fasta_id_2_idx_wfp.close()
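

# A minimal CLI sketch (an assumption: the original argument parser is not shown in
# this excerpt). Flag names are inferred from the `args.*` attributes used in main();
# the defaults below are illustrative, not the project's actual values.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Extract ESM embeddings from a FASTA file.")
    parser.add_argument("--model_name", type=str, default="esm2_t33_650M_UR50D")
    parser.add_argument("--file", type=str, required=True, help="input FASTA file")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--include", nargs="+", choices=["per_tok", "mean", "bos", "contacts"], default=["mean"]
    )
    parser.add_argument("--repr_layers", type=int, nargs="+", default=[-1])
    parser.add_argument("--truncation_seq_length", type=int, default=1022)
    parser.add_argument("--toks_per_batch", type=int, default=4096)
    parser.add_argument("--nogpu", action="store_true")
    parser.add_argument("--uncompleted_file", type=str, required=True)
    parser.add_argument("--fasta_id_2_idx_file", type=str, required=True)
    parser.add_argument("--begin_uuid_index", type=int, default=0)
    main(parser.parse_args())

# Reading a saved embedding back (a sketch; the path and layer index are hypothetical):
#   result = torch.load("output_dir/1.pt")
#   mean_emb = result["mean_representations"][33]  # (embed_dim,) tensor for layer 33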