in src/data.py [0:0]
def __call__(self, batch):
    """Collate a batch of examples into tensors for Fusion-in-Decoder style training.

    Args:
        batch: list of example dicts, each expected to contain:
            - 'index': int identifier of the example.
            - 'target': target answer string (must not be None).
            - 'question': question string.
            - 'passages': list of passage strings, or None when the example
              has no retrieved passages.

    Returns:
        Tuple of (index, target_ids, target_mask, passage_ids, passage_masks):
            - index: LongTensor of example indices, shape (batch,).
            - target_ids: tokenized targets with padded positions set to -100.
            - target_mask: BoolTensor attention mask for the targets.
            - passage_ids, passage_masks: encoded question+passage pairs
              from encode_passages (shape determined by that helper).
    """
    # Only the first example is checked: targets are assumed present for the
    # whole batch. `is not None` is the correct identity comparison (PEP 8).
    assert batch[0]['target'] is not None
    index = torch.tensor([ex['index'] for ex in batch])

    # Tokenize targets; truncate only when a positive max length is configured.
    # NOTE(review): `pad_to_max_length` is deprecated in recent transformers
    # versions in favor of `padding=` — kept as-is to preserve behavior on the
    # version this project pins; confirm before upgrading transformers.
    target = [ex['target'] for ex in batch]
    target = self.tokenizer.batch_encode_plus(
        target,
        max_length=self.answer_maxlength if self.answer_maxlength > 0 else None,
        pad_to_max_length=True,
        return_tensors='pt',
        truncation=self.answer_maxlength > 0,
    )
    target_ids = target["input_ids"]
    target_mask = target["attention_mask"].bool()
    # Replace padding token ids with -100 so the (cross-entropy) loss
    # ignores padded positions.
    target_ids = target_ids.masked_fill(~target_mask, -100)

    def append_question(example):
        # Prefix the question to every retrieved passage; an example with no
        # passages is encoded as the bare question.
        if example['passages'] is None:
            return [example['question']]
        return [example['question'] + " " + t for t in example['passages']]

    text_passages = [append_question(example) for example in batch]
    passage_ids, passage_masks = encode_passages(text_passages,
                                                 self.tokenizer,
                                                 self.text_maxlength)

    return (index, target_ids, target_mask, passage_ids, passage_masks)