def generate_answers()

in paq/generation/filtering/filterer.py [0:0]


    def generate_answers(self, examples):
        """Run the FiD consistency filter over `examples` and annotate each one.

        For every example, generates an answer from its context with the FiD
        model, marks the example as 'consistent' if the generated answer
        exact-matches the gold answers, and stores the generated answer under
        'filter_answer'. Mutates the underlying dataset examples in place.

        Args:
            examples: iterable of example dicts accepted by
                `self._get_dataloader_for_examples`; each must carry an
                'answers' field (list of gold answer strings).

        Returns:
            The annotated examples converted by `_get_reader_output_format`.
        """
        eval_dataset, eval_dataloader = self._get_dataloader_for_examples(examples)
        num_batches = len(eval_dataloader)
        total = 0
        exactmatch = []
        with torch.no_grad():
            for i, batch in enumerate(eval_dataloader):
                # Only the example indices and the context tensors are needed;
                # the middle fields of the batch tuple are ignored here.
                (idx, _, _, context_ids, context_mask) = batch

                outputs = self.model.generate(
                    input_ids=context_ids.to(self.device),
                    attention_mask=context_mask.to(self.device),
                    max_length=10,  # answers are short; cap generation length
                )
                for k, o in enumerate(outputs):
                    ans = self.tokenizer.decode(o, skip_special_tokens=True)
                    # idx[k] maps this generated output back to its source example
                    example = eval_dataset.data[idx[k]]
                    score = src.evaluation.ems(ans, example['answers'])
                    exactmatch.append(score)
                    example['consistent'] = score
                    example['filter_answer'] = ans
                    total += 1

                if (i + 1) % 10 == 0:
                    logger.info(f'FID filtering: {i+1} / {num_batches} | ave = {np.mean(exactmatch):.3f}')
        # Final summary. Guarded so an empty dataloader neither raises a
        # NameError on the loop variable nor computes np.mean of an empty list.
        if exactmatch:
            logger.info(f'FID filtering: {num_batches} / {num_batches} | ave = {np.mean(exactmatch):.3f}')
        output = _get_reader_output_format(examples)
        return output