in grok/training.py [0:0]
def _save_activations(self, outputs: Dict, ds: str) -> None:
    """
    Write the collected step outputs and/or activations to disk.

    What gets stored is controlled by two hyperparameters:
    ``save_outputs`` stores the concatenated ``y_hat_rhs`` tensors, and
    ``save_activations`` stores the merged ``partial_attentions`` and
    ``partial_values`` under the keys ``attentions`` and ``values``.
    When neither flag is set nothing is written.

    The file is saved with ``torch.save`` to
    ``<logdir>/outputs/<ds>/epoch_<current_epoch, zero-padded to 10>.pt``;
    the directory is created if it does not exist.

    :param outputs: collection of per-step results; it is iterated and each
                    element is indexed by string key, so despite the ``Dict``
                    annotation it is used as a sequence of mappings
                    (presumably what the step methods return -- confirm
                    against the caller)
    :param ds: name of the sub-directory under ``<logdir>/outputs`` that
               receives the file
    """
    hparams = self.hparams  # type: ignore
    payload: Dict[str, Any] = {}

    if hparams.save_outputs:
        payload["y_hat_rhs"] = torch.cat([step["y_hat_rhs"] for step in outputs])

    if hparams.save_activations:
        # Both activation kinds are gathered per step and merged the same way.
        for merged_key, partial_key in (
            ("attentions", "partial_attentions"),
            ("values", "partial_values"),
        ):
            per_step = [step[partial_key] for step in outputs]
            payload[merged_key] = self._merge_batch_activations(per_step)

    # Nothing was requested, so there is nothing to write.
    if not (hparams.save_outputs or hparams.save_activations):
        return

    target_dir = hparams.logdir + "/outputs/" + ds
    os.makedirs(target_dir, exist_ok=True)
    target_file = target_dir + f"/epoch_{self.current_epoch:010}.pt"
    with open(target_file, "wb") as sink:
        torch.save(payload, sink)