in src/data.py [0:0]
    def __call__(self, batch):
        # Gather the per-example dataset indices into a single tensor.
        index = torch.tensor([ex['index'] for ex in batch])

        # Tokenize all questions as one padded/truncated batch.
        question = [ex['question'] for ex in batch]
        question = self.tokenizer.batch_encode_plus(
            question,
            pad_to_max_length=True,  # deprecated in newer transformers; padding='max_length' is the modern equivalent
            return_tensors="pt",
            max_length=self.question_maxlength,
            truncation=True
        )
        question_ids = question['input_ids']
        question_mask = question['attention_mask'].bool()

        # If the examples carry no retrieval scores or passages,
        # return the question-only tensors.
        if batch[0]['scores'] is None or batch[0]['passages'] is None:
            return index, question_ids, question_mask, None, None, None

        # Stack the per-example score tensors along a new batch dimension.
        scores = [ex['scores'] for ex in batch]
        scores = torch.stack(scores, dim=0)

        # Tokenize each example's passage list into padded id/mask tensors.
        passages = [ex['passages'] for ex in batch]
        passage_ids, passage_masks = encode_passages(
            passages,
            self.tokenizer,
            self.passage_maxlength
        )

        return (index, question_ids, question_mask, passage_ids, passage_masks, scores)
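
# Illustrative sketch, not the repository's actual code: encode_passages is
# defined elsewhere in src/data.py. The minimal version below assumes it
# tokenizes each example's list of passages and stacks the results into
# (batch, n_passages, max_length) id/mask tensors; it also assumes every
# example in the batch has the same number of passages.
def encode_passages(batch_text_passages, tokenizer, max_length):
    passage_ids, passage_masks = [], []
    for text_passages in batch_text_passages:
        p = tokenizer.batch_encode_plus(
            text_passages,
            max_length=max_length,
            pad_to_max_length=True,
            return_tensors='pt',
            truncation=True
        )
        # Add a leading batch dimension: (1, n_passages, max_length).
        passage_ids.append(p['input_ids'][None])
        passage_masks.append(p['attention_mask'][None])
    passage_ids = torch.cat(passage_ids, dim=0)
    passage_masks = torch.cat(passage_masks, dim=0)
    return passage_ids, passage_masks.bool()

# Usage sketch (hypothetical names): the collator instance is passed to a
# DataLoader so each batch is assembled by __call__ above, e.g.
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collator)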