def mhop_collate()

in mdr/retrieval/data/sp_datasets.py [0:0]


def mhop_collate(samples, pad_id=0):
    """Collate a list of multi-hop retrieval samples into one batch dict.

    Each sample must hold five tokenizer outputs under the keys
    "q_codes", "pos_codes_1", "pos_codes_2", "neg_codes_1" and
    "neg_codes_2".  For every one of them, "input_ids" and
    "attention_mask" are flattened with ``.view(-1)`` and combined across
    the batch by ``collate_tokens``.  If the first sample's "q_codes" also
    carries "token_type_ids", those are collated for all five fields as
    well.

    Args:
        samples: list of per-example dicts as described above.  The values
            are presumably torch tensors (``.view(-1)`` is called on them);
            TODO confirm against the dataset's ``__getitem__``.
        pad_id: fill value passed to ``collate_tokens`` for the
            "input_ids" entries.  Attention masks and token type ids always
            use 0.

    Returns:
        dict mapping batch keys (e.g. "q_input_ids", "c_mask_1",
        "neg_type_ids_2") to the values returned by ``collate_tokens``.
    """
    # One row per sample field: (ids key, mask key, type-ids key, sample key).
    # Row order matches the order in which the batch keys are inserted.
    layout = (
        ("q_input_ids", "q_mask", "q_type_ids", "q_codes"),
        ("c_input_ids_1", "c_mask_1", "c_type_ids_1", "pos_codes_1"),
        ("c_input_ids_2", "c_mask_2", "c_type_ids_2", "pos_codes_2"),
        ("neg_input_ids_1", "neg_mask_1", "neg_type_ids_1", "neg_codes_1"),
        ("neg_input_ids_2", "neg_mask_2", "neg_type_ids_2", "neg_codes_2"),
    )

    def gather(field, entry, fill):
        # Flatten each sample's tensor, then let collate_tokens combine
        # them (presumably padding to a common length with `fill`).
        flat = [sample[field][entry].view(-1) for sample in samples]
        return collate_tokens(flat, fill)

    batch = {}
    for ids_key, mask_key, _type_key, field in layout:
        batch[ids_key] = gather(field, "input_ids", pad_id)
        batch[mask_key] = gather(field, "attention_mask", 0)

    # Only the first sample's question encoding is inspected. The other
    # fields are assumed to share the same tokenizer output format.
    has_type_ids = "token_type_ids" in samples[0]["q_codes"]
    if has_type_ids:
        for _ids_key, _mask_key, type_key, field in layout:
            batch[type_key] = gather(field, "token_type_ids", 0)

    return batch