in grok/training.py
@torch.no_grad()  # parameters are modified in place below, so run outside autograd
def second_step(self, zero_grad=False):
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

    self.base_optimizer.step()  # do the actual "sharpness-aware" update

    if zero_grad:
        self.zero_grad()
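
For context, `second_step` is the second half of SAM's two-pass update: a matching `first_step` (not shown here) perturbs the weights from "w" to "w + e(w)", and this method undoes that perturbation before letting the wrapped base optimizer take the real step. A minimal training-loop sketch under that assumption follows; `model`, `loss_fn`, and `loader` are illustrative names, not objects from this repo:

# SAM needs two forward/backward passes per batch: one at w to find the
# ascent direction e(w), one at w + e(w) for the actual parameter update.
for inputs, targets in loader:
    loss_fn(model(inputs), targets).backward()   # gradients at w
    optimizer.first_step(zero_grad=True)         # move to w + e(w)

    loss_fn(model(inputs), targets).backward()   # gradients at w + e(w)
    optimizer.second_step(zero_grad=True)        # restore w, apply base update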