def __call__()

in jamba1.5-retriever/scripts/train.py [0:0]


    def __call__(self, features):
        """Collate paired-sentence features into one batch dictionary.

        Each element of ``features`` must provide the keys ``input_ids1``,
        ``attention_mask1``, ``input_ids2``, ``attention_mask2`` and
        ``labels``. The two sentences are collated by two separate calls to
        the parent ``__call__``, which is relied on here for padding.
        NOTE(review): presumably each side is therefore padded independently
        of the other -- confirm against the parent collator.

        Args:
            features: Sequence of per-example mappings with the keys listed
                above. It is iterated more than once, so it must not be a
                one-shot iterator.

        Returns:
            A dict with ``input_ids1``/``attention_mask1`` and
            ``input_ids2``/``attention_mask2`` taken from the parent's output
            for each sentence, plus ``labels`` stacked along a new leading
            dimension.

        Raises:
            KeyError: If a feature is missing one of the required keys.
        """
        # Re-key each sentence to the plain 'input_ids'/'attention_mask'
        # names that the parent call is given.
        sentence1_features = [
            {'input_ids': f['input_ids1'], 'attention_mask': f['attention_mask1']}
            for f in features
        ]
        sentence2_features = [
            {'input_ids': f['input_ids2'], 'attention_mask': f['attention_mask2']}
            for f in features
        ]

        # Delegate padding/batching of each sentence to the parent method.
        batch_sentence1 = super().__call__(sentence1_features)
        batch_sentence2 = super().__call__(sentence2_features)

        # torch.stack only accepts tensors; as_tensor passes existing tensors
        # through untouched and converts plain numbers/lists/arrays, so labels
        # work whether or not the dataset already yields tensors.
        labels = torch.stack([torch.as_tensor(f['labels']) for f in features])

        # Combine both sentences and the labels into a single batch dictionary.
        return {
            'input_ids1': batch_sentence1['input_ids'],
            'attention_mask1': batch_sentence1['attention_mask'],
            'input_ids2': batch_sentence2['input_ids'],
            'attention_mask2': batch_sentence2['attention_mask'],
            'labels': labels,
        }