in longform-qa/lfqa_utils.py
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
q_ls = [q for q, a in qa_list]
a_ls = [a for q, a in qa_list]
q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
q_ids, q_mask = (
torch.LongTensor(q_toks["input_ids"]).to(device),
torch.LongTensor(q_toks["attention_mask"]).to(device),
)
a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
a_ids, a_mask = (
torch.LongTensor(a_toks["input_ids"]).to(device),
torch.LongTensor(a_toks["attention_mask"]).to(device),
)
lm_labels = a_ids[:, 1:].contiguous().clone()
lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
model_inputs = {
"input_ids": q_ids,
"attention_mask": q_mask,
"decoder_input_ids": a_ids[:, :-1].contiguous(),
"lm_labels": lm_labels,
}
return model_inputs
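

# Usage sketch (not part of the original file): build a batch from two (question, answer)
# pairs and run a seq2seq forward pass. The checkpoint name "facebook/bart-large", the CPU
# device, and the toy QA pairs are illustrative assumptions; the original training loop may
# differ. Note that recent transformers versions expect the key "labels" instead of "lm_labels".
if __name__ == "__main__":
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
    qa_pairs = [
        ("question: why is the sky blue?", "Sunlight scatters off air molecules ..."),
        ("question: how do magnets work?", "Moving electric charges create magnetic fields ..."),
    ]
    batch = make_qa_s2s_batch(qa_pairs, tokenizer, max_len=64, device="cpu")
    labels = batch.pop("lm_labels")  # rename for current transformers APIs
    outputs = model(**batch, labels=labels)
    print(outputs.loss)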