in paq/generation/filtering/filterer.py [0:0]
def generate_answers(self, examples):
    """Generate an answer for each example with the FiD model and mark consistency.

    For every example, the model generates a short answer (max 10 tokens) from
    the example's context; the generated answer is scored against the example's
    gold ``answers`` with exact-match (``src.evaluation.ems``). Each example
    dict is mutated in place: ``consistent`` holds the exact-match score and
    ``filter_answer`` holds the generated string.

    Args:
        examples: list of example dicts (each with an ``answers`` field) to
            filter; mutated in place as described above.

    Returns:
        The examples converted via ``_get_reader_output_format``.
    """
    eval_dataset, eval_dataloader = self._get_dataloader_for_examples(examples)
    exactmatch = []
    # Hoist the loop-invariant batch count used in progress logging.
    num_batches = len(eval_dataloader)
    with torch.no_grad():
        for i, batch in enumerate(eval_dataloader):
            (idx, _, _, context_ids, context_mask) = batch
            outputs = self.model.generate(
                input_ids=context_ids.to(self.device),
                attention_mask=context_mask.to(self.device),
                max_length=10,
            )
            for k, o in enumerate(outputs):
                ans = self.tokenizer.decode(o, skip_special_tokens=True)
                # idx maps the batch position back to the dataset row.
                example = eval_dataset.data[idx[k]]
                score = src.evaluation.ems(ans, example['answers'])
                exactmatch.append(score)
                example['consistent'] = score
                example['filter_answer'] = ans
            if (i + 1) % 10 == 0:
                logger.info(f'FID filtering: {i+1} / {num_batches} | ave = {np.mean(exactmatch):.3f}')
    # Guard the summary log: with an empty dataloader the original code raised
    # NameError (loop variable unbound) and np.mean([]) warns and returns nan.
    # For a non-empty loader, i+1 == num_batches here, so the message is unchanged.
    if exactmatch:
        logger.info(f'FID filtering: {num_batches} / {num_batches} | ave = {np.mean(exactmatch):.3f}')
    output = _get_reader_output_format(examples)
    return output