in src/run_fusion_in_decoder.py [0:0]
def test_epoch_end(self, outputs):
    """Aggregate test-step outputs, decode generations, and dump them to a JSONL file.

    Each element of ``outputs`` is expected to carry:
      - ``'test_loss'``: scalar loss tensor (averaged here),
      - ``'generations'``: generated token-id tensors,
      - ``'labels'``: reference token-id tensors.

    Since the model may emit ``num_return_sequences`` generations per input,
    each label is repeated that many times so predictions and references align
    one-to-one before decoding.

    Returns:
        dict with ``'avg_test_loss'`` and a ``'progress_bar'`` log entry.
    """
    avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
    preds = torch.cat([x['generations'] for x in outputs], dim=0).cpu().numpy().tolist()
    labels = torch.cat([x['labels'] for x in outputs], dim=0).cpu().numpy().tolist()

    num_return = self.hparams.data.num_return_sequences
    assert len(preds) == num_return * len(labels)
    # Repeat each label once per returned sequence so zip() pairs correctly.
    labels = [label for label in labels for _ in range(num_return)]
    assert len(preds) == len(labels)

    results = [
        {
            "tgt": self.tokenizer.decode(label),
            "gen": self.tokenizer.decode(pred),
        }
        for pred, label in zip(preds, labels)
    ]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(self.hparams.data.output_dir, exist_ok=True)
    out_name = (
        f"{self.hparams.model.model_name}"
        f"_{self.hparams.data.max_source_length}"
        f"_generated_beamsize_{self.hparams.data.num_beams}"
        f"_size_{num_return}.jsonl"
    )
    out_path = os.path.join(self.hparams.data.output_dir, out_name)
    with jsonlines.open(out_path, "w") as writer:
        writer.write_all(results)

    tensorboard_logs = {'test_loss': avg_loss}
    return {'avg_test_loss': avg_loss, 'progress_bar': tensorboard_logs}