def qa_collate()

in mdr/qa/qa_dataset.py
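
Collates a list of per-sample dicts into a batch: tensor fields are padded to a
common length and grouped under "net_inputs", while per-sample metadata (qids,
passages, gold answers, offsets) is kept as plain Python lists.

The padding helper `collate_tokens` is defined elsewhere in the repo. A minimal
sketch of the behavior this function assumes (fairseq-style right-padding of
1-D tensors; the repo's actual implementation may differ):

import torch

def collate_tokens(values, pad_idx):
    # pad a list of 1-D tensors into a (batch, max_len) tensor,
    # filling the tail of shorter rows with pad_idx
    size = max(v.size(0) for v in values)
    res = values[0].new_full((len(values), size), pad_idx)
    for i, v in enumerate(values):
        res[i, : v.size(0)].copy_(v)
    return res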


def qa_collate(samples, pad_id=0):
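    """Batch a list of QA samples.

    Tensor fields are padded with `collate_tokens` and grouped under
    "net_inputs"; per-sample metadata is passed through as Python lists.
    """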
    if not samples:
        return {}

    batch = {
        "input_ids": collate_tokens([s["encodings"]["input_ids"] for s in samples], pad_id),
        "attention_mask": collate_tokens([s["encodings"]["attention_mask"] for s in samples], 0),
        "paragraph_mask": collate_tokens([s["paragraph_mask"] for s in samples], 0),
        "label": collate_tokens([s["label"] for s in samples], -1),
        "sent_offsets": collate_tokens([s["sent_offsets"] for s in samples], 0),
    }

    # training-only labels, present when the samples carry supervision
    if "starts" in samples[0]:
        batch["starts"] = collate_tokens([s["starts"] for s in samples], -1)
        batch["ends"] = collate_tokens([s["ends"] for s in samples], -1)
        # batch["ans_types"] = collate_tokens([s["ans_type"] for s in samples], -1)
        batch["sent_labels"] = collate_tokens([s["sent_labels"] for s in samples], 0)
        batch["ans_covered"] = collate_tokens([s["ans_covered"] for s in samples], 0)

    # RoBERTa does not use token_type_ids; the key is only present for
    # tokenizers that produce it (e.g. BERT)
    if "token_type_ids" in samples[0]["encodings"]:
        batch["token_type_ids"] = collate_tokens([s["encodings"]["token_type_ids"] for s in samples], 0)

    # per-sample metadata stays as plain Python lists; padded tensors live under "net_inputs"
    batched = {
        "qids": [s["qid"] for s in samples],
        "passages": [s["passages"] for s in samples],
        "gold_answer": [s["gold_answer"] for s in samples],
        "sp_gold": [s["sp_gold"] for s in samples],
        "para_offsets": [s["para_offset"] for s in samples],
        "net_inputs": batch,
    }

    # for answer extraction: original whitespace tokens, wordpiece tokens, and
    # the wordpiece-to-original-token mapping used to map predicted spans back
    if "doc_tokens" in samples[0]:
        batched["doc_tokens"] = [s["doc_tokens"] for s in samples]
        batched["tok_to_orig_index"] = [s["tok_to_orig_index"] for s in samples]
        batched["wp_tokens"] = [s["wp_tokens"] for s in samples]

    return batched
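
Typical usage as a DataLoader collate_fn. The dataset, tokenizer, and model
names below are illustrative, not taken from the repo; pad_id should match the
tokenizer's pad token:

from functools import partial
from torch.utils.data import DataLoader

loader = DataLoader(
    qa_dataset,  # hypothetical dataset yielding dicts with the fields above
    batch_size=8,
    collate_fn=partial(qa_collate, pad_id=tokenizer.pad_token_id),
)
batch = next(iter(loader))
outputs = model(**batch["net_inputs"])  # only the padded tensors feed the model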