def main()

in src/protein_structure/embedding_from_esmfold.py
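
Loads a pretrained ESM model, runs every sequence in a FASTA file through it one at a time, and saves one .pt file per protein (representations selected via --include) together with a CSV mapping output indices to protein ids; sequences that fail (e.g. CUDA out of memory) are logged to a separate file for retry.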


# Imports inferred from the names used below (pretrained, MSATransformer,
# and FastaBatchedDataset are provided by fair-esm).
import csv
import os

import torch
from tqdm import tqdm

from esm import FastaBatchedDataset, MSATransformer, pretrained


def main(args):
    model, alphabet = pretrained.load_model_and_alphabet(args.model_name)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    if torch.cuda.is_available() and not args.nogpu:
        model = model.cuda()
        # print("Transferred model to GPU")

    # Load the FASTA file as a dataset of (identifier, sequence) pairs.
    dataset = FastaBatchedDataset.from_file(args.file)
    # Token-count batching was disabled in favor of one sequence per batch:
    # batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length), batch_sampler=batches
    # )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
    )

    print(f"Read {args.file} with {len(dataset)} sequences")
    os.makedirs(args.output_dir, exist_ok=True)
    return_contacts = "contacts" in args.include

    # Map negative layer indices (e.g. -1 for the final layer) onto [0, num_layers].
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]

    # Append to the bookkeeping files if they already exist (so interrupted runs
    # can be resumed); otherwise create them and write the CSV header.
    uncompleted_wfp = open(args.uncompleted_file, "a" if os.path.exists(args.uncompleted_file) else "w")
    had = os.path.exists(args.fasta_id_2_idx_file)
    fasta_id_2_idx_wfp = open(args.fasta_id_2_idx_file, "a" if had else "w", newline="")
    fasta_id_2_idx_writer = csv.writer(fasta_id_2_idx_wfp)
    if not had:
        fasta_id_2_idx_writer.writerow(["index", "uuid"])
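    # Output files are named "<index>.pt", where the index is incremented from
    # --begin_uuid_index before each save; the CSV above maps each index back
    # to its protein id.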
    protein_idx = args.begin_uuid_index
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Iteration")):
            # strs: the original sequences; toks: the tokenized batch tensor
            # (truncated/padded by the batch converter).
            protein_ids, strs, toks = batch
            # Normalize identifiers so each carries the FASTA ">" prefix.
            protein_ids = [">" + v.strip() if v and v[0] != ">" else v.strip() for v in protein_ids]
            if torch.cuda.is_available() and not args.nogpu:
                toks = toks.to(device="cuda", non_blocking=True)
            try:
                out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)
            except RuntimeError as e:
                if e.args[0].startswith("CUDA out of memory"):
                    if len(strs) > 1:
                        print(
                            f"Failed (CUDA out of memory) to predict batch of size {len(strs)}. "
                            "Try lowering `--toks_per_batch`."
                        )
                    else:
                        print(
                            f"Failed (CUDA out of memory) on sequence {protein_ids[0]} of length {len(strs[0])}."
                        )
                    # Record the skipped sequences so a later run can retry them.
                    for idx, v in enumerate(protein_ids):
                        uncompleted_wfp.write("%s,%d\n" % (v, len(strs[idx])))
                        uncompleted_wfp.flush()
                    continue
                # Any other RuntimeError is unexpected; do not swallow it.
                raise

            # logits = out["logits"].to(device="cpu")
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for idx, protein_id in enumerate(protein_ids):
                protein_idx += 1
                cur_output_file = os.path.join(args.output_dir, "%d.pt" % protein_idx)
                result = {"protein_id": protein_id, "seq": strs[idx], "seq_len": len(strs[idx]), "max_len": args.truncation_seq_length}

                truncate_len = min(args.truncation_seq_length, len(strs[idx]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in args.include:
                    result["representations"] = {
                        layer: t[idx, 1: truncate_len + 1].clone() for layer, t in representations.items()
                    }
                if "mean" in args.include:
                    result["mean_representations"] = {
                        layer: t[idx, 1: truncate_len + 1].mean(0).clone() for layer, t in representations.items()
                    }
                if "bos" in args.include:
                    result["bos_representations"] = {
                        layer: t[idx, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[idx, 1: truncate_len + 1, 1: truncate_len + 1].clone()

                torch.save(
                    result,
                    cur_output_file,
                )
                fasta_id_2_idx_writer.writerow([protein_idx, protein_id])
                fasta_id_2_idx_wfp.flush()
    uncompleted_wfp.close()
    fasta_id_2_idx_wfp.close()
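
For reference, here is a minimal, hypothetical command-line driver for main(). The argument names are taken from the attributes accessed on args in the function above; the defaults, help strings, and the create_parser name are assumptions, not part of the original script.

import argparse


def create_parser():
    parser = argparse.ArgumentParser(
        description="Extract per-protein ESM embeddings from a FASTA file."
    )
    parser.add_argument("--model_name", type=str, default="esm2_t33_650M_UR50D",
                        help="pretrained model name resolved by esm.pretrained")
    parser.add_argument("--file", type=str, required=True, help="input FASTA file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="directory for the per-sequence .pt files")
    parser.add_argument("--include", nargs="+", default=["mean"],
                        choices=["per_tok", "mean", "bos", "contacts"],
                        help="which representations to save")
    parser.add_argument("--repr_layers", type=int, nargs="+", default=[-1],
                        help="layers whose representations are extracted")
    parser.add_argument("--truncation_seq_length", type=int, default=1022,
                        help="truncate sequences longer than this")
    parser.add_argument("--uncompleted_file", type=str, required=True,
                        help="log of sequences skipped (e.g. CUDA OOM)")
    parser.add_argument("--fasta_id_2_idx_file", type=str, required=True,
                        help="CSV mapping output index -> protein id")
    parser.add_argument("--begin_uuid_index", type=int, default=0,
                        help="starting index for output file names")
    parser.add_argument("--nogpu", action="store_true",
                        help="run on CPU even if CUDA is available")
    return parser


if __name__ == "__main__":
    main(create_parser().parse_args())

An invocation would then look like: python embedding_from_esmfold.py --file proteins.fasta --output_dir embeddings --uncompleted_file uncompleted.csv --fasta_id_2_idx_file id2idx.csv --include mean per_tok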