in grok/training.py [0:0]
def test_epoch_end(self, outputs):
"""
Used by pytorch_lightning
Accumulates results of all forward validation passes in this epoch
:param outputs: a list of dicts from self.validation_step()
:param batch_idx: which batch this is in the epoch.
:returns: a dict with val_loss, val_accuracy
"""
loss = torch.cat([x["partial_test_loss"] for x in outputs], dim=0) # .sum()
# loss = list([x["partial_test_loss"] for x in outputs]) # .sum()
perplexity = torch.exp(loss)
accuracy = torch.cat([x["partial_test_accuracy"] for x in outputs], dim=0)
logs = {
"test_loss": loss,
"test_accuracy": accuracy,
"test_perplexity": perplexity,
}
return {"test_loss": loss, "log": logs}