in grok/training.py [0:0]
import torch


def _grad_norm(self):
    # Compute the global L2 norm of all gradients across parameter groups.
    # Collect every per-parameter norm on one device, in case of model parallelism.
    shared_device = self.param_groups[0]["params"][0].device
    grad_norms = [
        p.grad.norm(p=2).to(shared_device)
        for group in self.param_groups
        for p in group["params"]
        if p.grad is not None
    ]
    # Guard against an empty list (no gradients yet), which torch.stack rejects.
    if not grad_norms:
        return torch.tensor(0.0, device=shared_device)
    # The global norm is the L2 norm of the stacked per-parameter norms.
    return torch.norm(torch.stack(grad_norms), p=2)
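
A minimal usage sketch, assuming _grad_norm is meant to live on a torch.optim.Optimizer subclass (self.param_groups only exists there). The GradNormMixin and AdamWWithGradNorm names and the toy linear model are hypothetical, for illustration only:

import torch
import torch.nn as nn


class GradNormMixin:
    """Hypothetical mixin adding _grad_norm to any torch.optim.Optimizer subclass."""

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        grad_norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not grad_norms:
            return torch.tensor(0.0, device=shared_device)
        return torch.norm(torch.stack(grad_norms), p=2)


class AdamWWithGradNorm(GradNormMixin, torch.optim.AdamW):
    pass  # illustrative name, not from grok/training.py


model = nn.Linear(10, 1)
opt = AdamWWithGradNorm(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
print(f"global grad norm: {opt._grad_norm().item():.4f}")  # one scalar tensor
opt.step()
opt.zero_grad()

Stacking the per-parameter norms and taking their L2 norm is equivalent to the L2 norm of all gradient elements concatenated, but avoids materializing one large flattened tensor.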