in longform-qa/lfqa_utils.py [0:0]
def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
    """Tokenize (question, answer) pairs into tensors for the QA retriever.

    Args:
        qa_list: Iterable of ``(question, answer)`` pairs. Any iterable is
            accepted, including one-shot iterators such as generators.
        tokenizer: Object exposing ``batch_encode_plus`` whose result can be
            indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_len: Value forwarded to the tokenizer as ``max_length``.
        device: Target passed to ``Tensor.to`` for every returned tensor.

    Returns:
        Tuple ``(q_ids, q_mask, a_ids, a_mask)`` of ``torch.LongTensor``
        objects moved to ``device``. Presumably each has shape
        ``(number_of_pairs, max_len)`` when the tokenizer pads and truncates
        to ``max_len`` -- confirm against the tokenizer in use.
    """
    # Materialize once: iterating ``qa_list`` twice would leave the answers
    # empty whenever the caller hands in a generator or other one-shot iterator.
    pairs = list(qa_list)
    questions = [question for question, _ in pairs]
    answers = [answer for _, answer in pairs]

    def _encode_to_device(texts):
        # Questions and answers go through exactly the same encoding path.
        # NOTE(review): newer transformers releases deprecate
        # ``pad_to_max_length`` in favour of ``padding="max_length"`` (plus an
        # explicit ``truncation=True``) -- confirm the pinned version before
        # switching.
        encoded = tokenizer.batch_encode_plus(texts, max_length=max_len, pad_to_max_length=True)
        return (
            torch.LongTensor(encoded["input_ids"]).to(device),
            torch.LongTensor(encoded["attention_mask"]).to(device),
        )

    q_ids, q_mask = _encode_to_device(questions)
    a_ids, a_mask = _encode_to_device(answers)
    return (q_ids, q_mask, a_ids, a_mask)