def predict_embedding()

in src/protein_structure/predict_structure.py [0:0]


def predict_embedding(sample, trunc_type, embedding_type, repr_layers=None, truncation_seq_length=4094, device=None):
    """
    Use a protein sequence to predict its embedding matrix or its BOS vector.

    The model/alphabet come from the module-level globals ``model`` and
    ``alphabet``; if either is None, "esm2_t36_3B_UR50D" is loaded with
    ``pretrained.load_model_and_alphabet`` and stored in those globals.

    :param sample: [protein_id, protein_sequence]
    :param trunc_type: "left" keeps the LAST ``truncation_seq_length`` residues
        (i.e. drops the left part); any other value keeps the FIRST
        ``truncation_seq_length`` residues.
    :param embedding_type: "bos" returns the vector at token position 0 (BOS);
        "representations" or "matrix" return the per-residue matrix
        (token positions 1..L, BOS excluded).
    :param repr_layers: layer indices passed to the model; negative indices
        count from the end. Defaults to [-1] (the final layer). The returned
        embedding is taken from the LAST layer in this list.
    :param truncation_seq_length: maximum number of residues kept, e.g. one of
        [4094, 2046, 1982, 1790, 1534, 1278, 1150, 1022]
    :param device: torch.device or device string (e.g. "cuda:0"). If None, the
        model's current device is used; otherwise the model is moved there.
    :return: (embedding, protein_seq), where embedding is a numpy array and
        protein_seq is the (possibly truncated) sequence that was embedded.
        Returns (None, None) if a RuntimeError occurs on a non-CPU device, so
        the caller can retry on CPU.
    :raises RuntimeError: re-raised unchanged when the failure happens on CPU.
    """
    global model, alphabet
    assert embedding_type in ["bos", "representations", "matrix"]
    # None sentinel instead of a mutable list default; same default behaviour.
    if repr_layers is None:
        repr_layers = [-1]
    protein_id, protein_seq = sample[0], sample[1]
    if len(protein_seq) > truncation_seq_length:
        if trunc_type == "left":
            protein_seq = protein_seq[-truncation_seq_length:]
        else:
            protein_seq = protein_seq[:truncation_seq_length]
    if model is None or alphabet is None:
        model, alphabet = pretrained.load_model_and_alphabet("esm2_t36_3B_UR50D")
    num_layers = model.num_layers
    assert len(repr_layers) > 0 and all(-(num_layers + 1) <= i <= num_layers for i in repr_layers)
    # Map negative indices (e.g. -1) to absolute layer numbers in [0, num_layers].
    repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]
    # Read the embedding from the requested layer instead of a hard-coded 36,
    # which only matched the 36-layer model with the default repr_layers.
    target_layer = repr_layers[-1]
    model.eval()
    model_device = next(model.parameters()).device
    if device is None:
        device = model_device
    else:
        # Normalize so that string devices work with `device.type` below.
        device = torch.device(device)
        if device != model_device:
            model = model.to(device)
    converter = BatchConverter(alphabet, truncation_seq_length)
    _, raw_seqs, tokens = converter([[protein_id, protein_seq]])
    with torch.no_grad():
        tokens = tokens.to(device=device, non_blocking=True)
        try:
            out = model(tokens, repr_layers=repr_layers, return_contacts=False)
            # First (and only) sequence in the batch.
            layer_repr = out["representations"][target_layer][0]
            if embedding_type in ["representations", "matrix"]:
                # Skip position 0 (BOS) and keep one row per residue.
                truncate_len = min(truncation_seq_length, len(raw_seqs[0]))
                embedding = layer_repr[1: truncate_len + 1]
            else:
                # Position 0 holds the BOS token.
                embedding = layer_repr[0]
            # Slice before moving so only the needed rows are copied to CPU;
            # clone() guarantees the array does not share the model's storage.
            embedding = embedding.to(device="cpu").clone().numpy()
            return embedding, protein_seq
        except RuntimeError as e:
            # str(e) is safe even when e.args is empty or non-string.
            if "CUDA out of memory" in str(e):
                print(f"Failed (CUDA out of memory) on sequence {sample[0]} of length {len(sample[1])}.")
                print("Please reduce the 'truncation_seq_length'")
            if device.type == "cpu":
                # Insufficient CPU memory (or other CPU failure): propagate with the original traceback.
                raise
            # Failure on GPU: return None so the caller can continue on CPU.
            return None, None