in grok/training.py [0:0]
def step(self, closure=None):
    """Run one Sharpness Aware Minimization update driven by ``closure``.

    The sequence is: ``self.first_step(zero_grad=True)``, one call to
    ``closure`` with gradient tracking enabled, then ``self.second_step()``.

    Args:
        closure: Callable taking no arguments. It should do a full
            forward-backward pass. Required, despite the ``None`` default
            kept for signature compatibility.

    Returns:
        Whatever ``closure`` returns (conventionally the loss). Note that the
        closure is evaluated after ``first_step`` has run, so the value
        reflects the state ``first_step`` left behind.
        # NOTE(review): presumably that is the perturbed weights -- confirm
        # against ``first_step``.

    Raises:
        AssertionError: If ``closure`` is ``None``.
    """
    if closure is None:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``. Without it, ``first_step`` would run before the
        # missing closure was noticed, leaving the update half-applied.
        raise AssertionError(
            "Sharpness Aware Minimization requires closure, but it was not provided"
        )

    # Force gradient tracking on inside the closure even if the caller (or a
    # decorator on this method) has disabled it; the closure should do a full
    # forward-backward pass.
    closure = torch.enable_grad()(closure)

    self.first_step(zero_grad=True)
    loss = closure()
    self.second_step()
    return loss