in grok/training.py
def validation_epoch_end(self, outputs):
    """
    Used by pytorch_lightning
    Accumulates the results of all forward validation passes in this epoch

    :param outputs: a list of dicts from self.validation_step()
    :returns: a dict with val_loss, val_accuracy, and val_perplexity
        (only on epochs where validation actually ran)
    """
    validation_is_real = len(outputs[0]) != 0

    if validation_is_real:
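        # Push the next evaluation epoch out geometrically (~2% further away
        # each time), advancing by at least one epoch.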
        self.next_epoch_to_eval = max(
            int(1.02 * self.next_epoch_to_eval), self.next_epoch_to_eval + 1
        )
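
        # Summing the per-batch partial metrics gives the epoch-level loss and
        # accuracy; perplexity is exp of that loss.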
        loss = torch.stack([x["partial_val_loss"] for x in outputs]).sum()
        perplexity = torch.exp(loss)
        accuracy = torch.stack([x["partial_val_accuracy"] for x in outputs]).sum()
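
        # Optionally save the validation inputs (only on the first epoch) and
        # the captured activations/outputs for offline analysis.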
        if self.hparams.save_activations or self.hparams.save_outputs:
            if self.current_epoch == 0:
                self._save_inputs(outputs, ds="val")
            self._save_activations(outputs, ds="val")

        logs = {
            "val_loss": loss,
            "val_accuracy": accuracy,
            "val_perplexity": perplexity,
        }
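
        # Log the L2 norm of each parameter tensor, normalized by the square
        # root of its element count so tensors of different sizes are comparable.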
        for name, param in self.named_parameters():
            # n parameters
            n_params = param.numel()
            # get the l2 norm of the parameter
            logs["paramnorm_" + name] = (
                torch.norm(param, 2).detach().cpu().numpy() / np.sqrt(n_params)
            )

        # Compute loss/accuracy over the entire training set in one no-grad
        # forward pass, shifting the raw data into (text, target) next-token pairs.
        device = self.transformer.embedding.weight.device
        train_data = self.train_dataset.data.to(device)
        training_data = {"text": train_data[:, :-1], "target": train_data[:, 1:]}
        with torch.no_grad():
            tr_loss, tr_acc, *_ = self._step(training_data, 0)
            logs["full_train_loss"] = tr_loss
            logs["full_train_acc"] = tr_acc

        for k, v in logs.items():
            self.log(k, v)

    # save a checkpoint if the epoch is a power of 2
    if (
        self.current_epoch > 0
        and int(2 ** (int(np.log(self.current_epoch) / np.log(2))))
        == self.current_epoch
    ):
        self.trainer.save_checkpoint(
            os.path.join(
                self.hparams.checkpoint_path,
                "epoch_" + str(self.current_epoch) + ".ckpt",
            )
        )
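
    # Metrics are returned only on epochs where validation actually ran.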
    if validation_is_real:
        return logs