in mdr/retrieval/data/unified_dataset.py [0:0]
def nq_unified_collate(samples, pad_id=0):
    """Collate per-sample encodings into one batch dictionary.

    Each sample is expected to be a mapping with the entries ``q_codes``,
    ``q_neg1_codes``, ``pos_codes``, ``neg_codes``, ``dense_neg1_codes`` and
    ``dense_neg2_codes``; each of those holds ``input_ids`` and
    ``attention_mask`` tensors and, optionally, ``token_type_ids``.

    Args:
        samples: Sequence of per-example mappings as described above.
        pad_id: Fill value handed to ``collate_tokens`` for every
            ``*_input_ids`` field. Attention masks and token type ids are
            always filled with 0.

    Returns:
        An empty dict when ``samples`` is empty. Otherwise a dict with
        ``<prefix>_input_ids`` and ``<prefix>_mask`` for the prefixes ``q``,
        ``q_neg1``, ``c``, ``neg``, ``dense_neg1`` and ``dense_neg2``, plus
        ``<prefix>_type_ids`` when the first sample's ``q_codes`` contains
        ``token_type_ids``.
    """
    if not samples:
        return {}

    def gather(codes_key, field, fill, flatten=False):
        # Pull one tensor per sample and hand the list to collate_tokens
        # (defined elsewhere; presumably pads to a common length - confirm).
        tensors = [s[codes_key][field] for s in samples]
        if flatten:
            # Query-side tensors are flattened to 1-D before collation.
            tensors = [t.view(-1) for t in tensors]
        return collate_tokens(tensors, fill)

    # Every input_ids field uses the caller's pad_id. Previously only
    # q_input_ids did and the rest were hard-coded to 0, which is a real
    # vocabulary id for tokenizers whose pad id is not 0.
    batch = {
        'q_input_ids': gather("q_codes", "input_ids", pad_id, flatten=True),
        'q_mask': gather("q_codes", "attention_mask", 0, flatten=True),
        'q_neg1_input_ids': gather("q_neg1_codes", "input_ids", pad_id, flatten=True),
        'q_neg1_mask': gather("q_neg1_codes", "attention_mask", 0, flatten=True),
        'c_input_ids': gather("pos_codes", "input_ids", pad_id),
        'c_mask': gather("pos_codes", "attention_mask", 0),
        'neg_input_ids': gather("neg_codes", "input_ids", pad_id),
        'neg_mask': gather("neg_codes", "attention_mask", 0),
        'dense_neg1_input_ids': gather("dense_neg1_codes", "input_ids", pad_id),
        'dense_neg1_mask': gather("dense_neg1_codes", "attention_mask", 0),
        'dense_neg2_input_ids': gather("dense_neg2_codes", "input_ids", pad_id),
        'dense_neg2_mask': gather("dense_neg2_codes", "attention_mask", 0),
    }

    # Only the first sample's query encoding is inspected; the other
    # encodings are assumed to carry token_type_ids whenever it does.
    if "token_type_ids" in samples[0]["q_codes"]:
        batch.update({
            'q_type_ids': gather("q_codes", "token_type_ids", 0, flatten=True),
            'c_type_ids': gather("pos_codes", "token_type_ids", 0),
            'q_neg1_type_ids': gather("q_neg1_codes", "token_type_ids", 0, flatten=True),
            'neg_type_ids': gather("neg_codes", "token_type_ids", 0),
            'dense_neg1_type_ids': gather("dense_neg1_codes", "token_type_ids", 0),
            'dense_neg2_type_ids': gather("dense_neg2_codes", "token_type_ids", 0),
        })

    return batch