# jamba1.5-retriever/scripts/train.py [0:0]
def __call__(self, features):
    """Collate paired-sentence features into one padded batch.

    Each feature dict holds two tokenized sentences under the suffixed keys
    ``input_ids1``/``attention_mask1`` and ``input_ids2``/``attention_mask2``,
    plus a ``labels`` entry. The parent collator pads each sentence side
    independently; the padded tensors are then merged back under the suffixed
    keys and the labels are stacked into a single tensor.

    Args:
        features: list of per-example dicts as described above.
            NOTE(review): ``labels`` entries are assumed to already be
            tensors (``torch.stack`` requires tensors) — confirm upstream.

    Returns:
        dict with keys ``input_ids1``, ``attention_mask1``, ``input_ids2``,
        ``attention_mask2`` and ``labels``, each a batched tensor.
    """
    # Pad each sentence side separately via the parent collator, so the two
    # sides may end up with different sequence lengths.
    first = super().__call__(
        [{'input_ids': ex['input_ids1'], 'attention_mask': ex['attention_mask1']}
         for ex in features]
    )
    second = super().__call__(
        [{'input_ids': ex['input_ids2'], 'attention_mask': ex['attention_mask2']}
         for ex in features]
    )
    # Re-attach the suffixes and batch the labels.
    return {
        'input_ids1': first['input_ids'],
        'attention_mask1': first['attention_mask'],
        'input_ids2': second['input_ids'],
        'attention_mask2': second['attention_mask'],
        'labels': torch.stack([ex['labels'] for ex in features]),
    }