# def mhop_collate()
#
# in mdr/retrieval/data/mhop_dataset.py [0:0]


def mhop_collate(samples, pad_id=0):
    """Merge per-example multi-hop encodings into a single batch dict.

    Each sample is expected to be a mapping with the entries ``q_codes``,
    ``q_sp_codes``, ``start_para_codes``, ``bridge_para_codes``,
    ``neg_codes_1`` and ``neg_codes_2``; each of those is itself a mapping
    holding at least ``input_ids`` and ``attention_mask``.  The matching
    field of every sample is gathered into a list and handed to
    ``collate_tokens`` (defined outside this function; presumably it pads
    the sequences to a common length with the value given as its second
    argument -- confirm against its definition).

    Args:
        samples: Sequence of per-example dicts as described above.
        pad_id: Fill value passed to ``collate_tokens`` for every
            ``*_input_ids`` entry.  Attention masks and token type ids are
            always filled with 0, sentence targets with -1.

    Returns:
        An empty dict when ``samples`` is empty.  Otherwise a dict with
        ``{q,q_sp,c1,c2,neg1,neg2}_input_ids`` and ``..._mask`` entries,
        plus ``..._type_ids`` entries when the first sample's ``q_codes``
        contains ``token_type_ids``, plus ``c1_sent_target`` and
        ``c1_sent_offsets`` when the first sample's ``start_para_codes``
        contains ``sent_ids``.
    """
    if not samples:
        return {}

    def _collate(codes_key, field, fill, flatten=False):
        # Gather one field from every sample and collate it in one call.
        tensors = [s[codes_key][field] for s in samples]
        if flatten:
            # Question encodings are flattened with .view(-1) before
            # collation, exactly as the original code did for q / q_sp.
            tensors = [t.view(-1) for t in tensors]
        return collate_tokens(tensors, fill)

    batch = {
        'q_input_ids': _collate("q_codes", "input_ids", pad_id, flatten=True),
        'q_mask': _collate("q_codes", "attention_mask", 0, flatten=True),

        'q_sp_input_ids': _collate("q_sp_codes", "input_ids", pad_id, flatten=True),
        'q_sp_mask': _collate("q_sp_codes", "attention_mask", 0, flatten=True),

        'c1_input_ids': _collate("start_para_codes", "input_ids", pad_id),
        'c1_mask': _collate("start_para_codes", "attention_mask", 0),

        'c2_input_ids': _collate("bridge_para_codes", "input_ids", pad_id),
        'c2_mask': _collate("bridge_para_codes", "attention_mask", 0),

        'neg1_input_ids': _collate("neg_codes_1", "input_ids", pad_id),
        'neg1_mask': _collate("neg_codes_1", "attention_mask", 0),

        'neg2_input_ids': _collate("neg_codes_2", "input_ids", pad_id),
        'neg2_mask': _collate("neg_codes_2", "attention_mask", 0),
    }

    # Only the first sample is inspected to decide whether the optional
    # fields exist; every sample is then assumed to carry them.
    first = samples[0]

    if "token_type_ids" in first["q_codes"]:
        batch.update({
            'q_type_ids': _collate("q_codes", "token_type_ids", 0, flatten=True),
            'c1_type_ids': _collate("start_para_codes", "token_type_ids", 0),
            'c2_type_ids': _collate("bridge_para_codes", "token_type_ids", 0),
            'q_sp_type_ids': _collate("q_sp_codes", "token_type_ids", 0, flatten=True),
            'neg1_type_ids': _collate("neg_codes_1", "token_type_ids", 0),
            'neg2_type_ids': _collate("neg_codes_2", "token_type_ids", 0),
        })

    if "sent_ids" in first["start_para_codes"]:
        batch["c1_sent_target"] = _collate("start_para_codes", "sent_ids", -1)
        # No trailing comma here: the entry must be the collated result
        # itself, not a 1-tuple wrapping it.
        batch["c1_sent_offsets"] = _collate("start_para_codes", "sent_offsets", 0)

    return batch