in mdr/qa/qa_dataset.py [0:0]
def qa_collate(samples, pad_id=0):
    """Collate a list of QA samples into padded batch tensors plus per-sample metadata."""
    if len(samples) == 0:
        return {}

    # tensors fed to the network, padded to the longest sequence in the batch
    batch = {
        'input_ids': collate_tokens([s["encodings"]['input_ids'] for s in samples], pad_id),
        'attention_mask': collate_tokens([s["encodings"]['attention_mask'] for s in samples], 0),
        'paragraph_mask': collate_tokens([s['paragraph_mask'] for s in samples], 0),
        'label': collate_tokens([s["label"] for s in samples], -1),
        "sent_offsets": collate_tokens([s["sent_offsets"] for s in samples], 0),
    }

    # training labels
    if "starts" in samples[0]:
        batch["starts"] = collate_tokens([s['starts'] for s in samples], -1)
        batch["ends"] = collate_tokens([s['ends'] for s in samples], -1)
        # batch["ans_types"] = collate_tokens([s['ans_type'] for s in samples], -1)
        batch["sent_labels"] = collate_tokens([s['sent_labels'] for s in samples], 0)
        batch["ans_covered"] = collate_tokens([s['ans_covered'] for s in samples], 0)

    # roberta does not use token_type_ids
    if "token_type_ids" in samples[0]["encodings"]:
        batch["token_type_ids"] = collate_tokens([s["encodings"]['token_type_ids'] for s in samples], 0)

    # per-sample metadata kept alongside the network inputs
    batched = {
        "qids": [s["qid"] for s in samples],
        "passages": [s["passages"] for s in samples],
        "gold_answer": [s["gold_answer"] for s in samples],
        "sp_gold": [s["sp_gold"] for s in samples],
        "para_offsets": [s["para_offset"] for s in samples],
        "net_inputs": batch,
    }

    # for answer extraction
    if "doc_tokens" in samples[0]:
        batched["doc_tokens"] = [s["doc_tokens"] for s in samples]
        batched["tok_to_orig_index"] = [s["tok_to_orig_index"] for s in samples]
        batched["wp_tokens"] = [s["wp_tokens"] for s in samples]

    return batched
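`collate_tokens` is imported elsewhere in the module; the excerpt assumes it is a fairseq-style padding helper that right-pads variable-length 1-D tensors with a fill value and stacks them into a single 2-D batch tensor. A minimal reference sketch under that assumption, plus the typical DataLoader hookup, is:

import functools
import torch
from torch.utils.data import DataLoader

def collate_tokens(values, pad_id):
    # Right-pad each 1-D tensor to the longest length in the batch, then stack.
    max_len = max(v.size(0) for v in values)
    out = values[0].new_full((len(values), max_len), pad_id)
    for i, v in enumerate(values):
        out[i, : v.size(0)].copy_(v)
    return out

# Typical usage (assumed names: qa_dataset, tokenizer): plug qa_collate into a
# DataLoader so every batch comes back padded with the tokenizer's pad id.
# loader = DataLoader(qa_dataset, batch_size=8,
#                     collate_fn=functools.partial(qa_collate, pad_id=tokenizer.pad_token_id))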