def preprocess_seqs()

in lmgvp/data_loaders.py


def preprocess_seqs(tokenizer, dataset):
    """Preprocess seq in dataset and bind the input_ids, attention_mask.

    Args:
        tokenizer: hugging face artifact. Tokenization to be used in the sequence.
        dataset: Dictionary containing the GVP dataset of proteins.

    Return:
        Input dataset with `input_ids` and `attention_mask`
    """
    seqs = [prep_seq(rec["seq"]) for rec in dataset]
    encodings = tokenizer(seqs, return_tensors="pt", padding=True)
    # add input_ids, attention_mask to the json records
    for i, rec in enumerate(dataset):
        rec["input_ids"] = encodings["input_ids"][i]
        rec["attention_mask"] = encodings["attention_mask"][i]
    return dataset
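
A minimal usage sketch (not taken from the source): it assumes a ProtBERT-style tokenizer from Hugging Face `transformers` and hypothetical records that carry their amino-acid sequence under the "seq" key; substitute whichever tokenizer and dataset the project actually uses.

from transformers import BertTokenizer

from lmgvp.data_loaders import preprocess_seqs

# Hypothetical example records; real records come from the GVP protein dataset.
dataset = [
    {"seq": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"},
    {"seq": "GSHMSLFDFFKNKGSAATA"},
]

# The ProtBERT checkpoint here is an assumption, not confirmed by the source.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

dataset = preprocess_seqs(tokenizer, dataset)
print(dataset[0]["input_ids"].shape)       # padded token ids, shape (max_len,)
print(dataset[0]["attention_mask"].sum())  # number of non-padding positions

Because the tokenizer is called with padding=True, every record receives tensors of the same length (the longest sequence in the batch), which keeps downstream batching simple.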