def test_epoch_end()

in src/run_fusion_in_decoder.py [0:0]


    def test_epoch_end(self, outputs):
        """Aggregate per-batch test outputs, decode generations, and write results.

        Averages the test loss, flattens generated token ids and reference
        labels across batches, decodes both with the tokenizer, and dumps
        (target, generation) pairs to a JSONL file in the configured output
        directory. Follows the legacy PyTorch Lightning ``*_epoch_end`` hook
        convention.

        Args:
            outputs: list of per-batch dicts with keys ``'test_loss'`` (scalar
                tensor), ``'generations'`` (token-id tensor), and ``'labels'``
                (token-id tensor).

        Returns:
            dict with ``'avg_test_loss'`` and ``'progress_bar'`` entries.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        preds = torch.cat([x['generations'] for x in outputs], dim=0).cpu().numpy().tolist()
        labels = torch.cat([x['labels'] for x in outputs], dim=0).cpu().numpy().tolist()

        # Each source produces `num_return_sequences` generations, so repeat
        # every label that many times to align one-to-one with the flattened
        # predictions (generation order is assumed grouped per source).
        num_return = self.hparams.data.num_return_sequences
        assert len(preds) == num_return * len(labels)
        labels = [label for label in labels for _ in range(num_return)]
        assert len(preds) == len(labels)

        results = [
            {
                "tgt": self.tokenizer.decode(label),
                "gen": self.tokenizer.decode(pred),
            }
            for pred, label in zip(preds, labels)
        ]

        # exist_ok=True is race-safe: no check-then-create window if another
        # process creates the directory concurrently.
        os.makedirs(self.hparams.data.output_dir, exist_ok=True)

        # Implicit string concatenation keeps this readable; the resulting
        # filename is byte-identical to the single long f-string it replaces.
        out_name = (
            f"{self.hparams.model.model_name}_{self.hparams.data.max_source_length}"
            f"_generated_beamsize_{self.hparams.data.num_beams}"
            f"_size_{num_return}.jsonl"
        )
        with jsonlines.open(os.path.join(self.hparams.data.output_dir, out_name), "w") as writer:
            writer.write_all(results)

        tensorboard_logs = {'test_loss': avg_loss}

        return {'avg_test_loss': avg_loss, 'progress_bar': tensorboard_logs}