in grok/training.py [0:0]
def training_epoch_end(self, outputs):
"""
Used by pytorch_lightning
Accumulates results of all forward training passes in this epoch
:param outputs: a list of dicts from self.training_step()
:param batch_idx: which batch this is in the epoch.
:returns: a dict with loss, accuracy, lr, probabilities of solutions,
attentions, and values
"""
epoch_is_to_be_logged = self.current_epoch == self.next_train_epoch_to_log
if epoch_is_to_be_logged:
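        # Geometric logging schedule: the next epoch to log is ~1% after
        # this one (and at least one epoch later), so logging thins out
        # as the number of epochs grows.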
self.next_train_epoch_to_log = max(
int(1.01 * self.next_train_epoch_to_log),
self.next_train_epoch_to_log + 1,
)
with torch.no_grad():
try:
loss = torch.stack([x["partial_train_loss"] for x in outputs]).sum()
except Exception as e:
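                # Dump the raw outputs for debugging before re-raising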
print("!" * 80)
print(outputs)
                raise
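            # Perplexity is the exponential of the accumulated epoch loss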
perplexity = torch.exp(loss)
accuracy = torch.stack(
[x["partial_train_accuracy"] for x in outputs]
).sum()
# avg_lr = torch.stack([x["learning_rate"] for x in outputs]).mean()
# max_lr = torch.stack([x["learning_rate"] for x in outputs]).max()
# last_lr = outputs[-1]["learning_rate"]
first_lr = outputs[0]["learning_rate"]
if self.hparams.save_activations or self.hparams.save_outputs:
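            # Inputs only need to be saved once, on the first epoch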
if self.current_epoch == 0:
self._save_inputs(outputs, ds="train")
self._save_activations(outputs, ds="train")
logs = {
"train_loss": loss,
"train_accuracy": accuracy,
"train_perplexity": perplexity,
"learning_rate": first_lr,
"len_train_ds": len(self.train_dataset),
"len_val_ds": len(self.val_dataset),
"batches_per_epoch": self.batches_per_epoch,
"time_per_epoch": time.time() - self.training_epoch_start_time,
"fwd_time_in_epoch": self.fwd_time_in_epoch,
}
for k, v in logs.items():
self.log(k, v)
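
For intuition, the `next_train_epoch_to_log` update above produces a schedule that logs every epoch early in training and then spaces logged epochs roughly geometrically (about 1% apart) on long runs. A minimal standalone sketch of that schedule follows; the `logged_epochs` helper, the epoch cap, and the starting value of 0 are illustrative assumptions, not part of the source:

def logged_epochs(max_epoch):
    """Yield the epoch indices the schedule above would log."""
    next_to_log = 0  # assumed initial value of self.next_train_epoch_to_log
    for epoch in range(max_epoch):
        if epoch == next_to_log:
            yield epoch
            # Same update rule as in training_epoch_end above
            next_to_log = max(int(1.01 * next_to_log), next_to_log + 1)

# Every epoch is logged until the 1% growth outpaces the +1 fallback
# (around epoch 200); after that the gap between logged epochs grows
# by roughly 1% each time.
print(list(logged_epochs(1000))[-5:])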