in grok/training.py [0:0]
def step(self, closure=None):
    """Performs a single optimization step (Adam with selectable weight decay).

    Every parameter that has a gradient receives a bias-corrected Adam update
    (AMSGrad if ``group["amsgrad"]``). When ``group["weight_decay"] > 0``, one
    of the following weight-decay forms, chosen by
    ``group["weight_decay_form"]``, is applied as well:

    * ``"to_zero"`` -- decoupled shrink of the parameter: ``p *= 1 - lr * wd``.
    * ``"to_init"`` -- decoupled pull toward the snapshot of the parameter
      taken when its optimizer state was first created:
      ``p += (init - p) * lr * wd``.
    * ``"jiggle"``  -- multiply ``p`` by ``exp(z * lr * wd)`` with a single
      standard-normal scalar ``z`` drawn per parameter tensor per step.
    * ``"honest"``  -- L2 penalty: ``wd * p`` is added to the gradient before
      the moment estimates are updated.

    If ``group["noise_factor"] > 0``, elementwise Gaussian noise with standard
    deviation ``noise_factor`` is added to the normalized update.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        Whatever ``closure`` returned, or ``None`` when no closure is given.

    Raises:
        RuntimeError: if a parameter has a sparse gradient.
        ValueError: if ``weight_decay > 0`` and ``weight_decay_form`` is not
            one of the forms listed above.
    """
    loss = None
    if closure is not None:
        # The parameter update below runs with autograd disabled; the closure
        # needs it enabled for its forward/backward pass.
        with torch.enable_grad():
            loss = closure()

    # In-place updates of leaf parameters that require grad must not be
    # recorded by autograd (PyTorch raises otherwise). Harmless if this method
    # is already decorated with ``@torch.no_grad()``.
    with torch.no_grad():
        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            decay_form = group["weight_decay_form"]
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            noise_factor = group["noise_factor"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                # Check before any weight-decay term is folded in: adding a
                # dense tensor could silently densify a sparse gradient.
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                if weight_decay > 0 and decay_form == "honest":
                    # L2 penalty goes through the gradient (and therefore the
                    # adaptive normalization), unlike the decoupled forms below.
                    grad = grad + weight_decay * p.detach()

                state = self.state[p]
                # Lazy state initialization on the first step for this param.
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if decay_form == "to_init":
                        # Anchor for "to_init" decay, captured before this
                        # step modifies the parameter.
                        state["init"] = p.detach().clone()
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                if weight_decay > 0:
                    decay = lr * weight_decay
                    if decay_form == "to_zero":
                        p.mul_(1 - decay)
                    elif decay_form == "to_init":
                        p.add_((state["init"] - p) * decay)
                    elif decay_form == "jiggle":
                        # Draw from the CPU generator (as before) and move the
                        # scalar to the parameter's device rather than
                        # hard-coding .cuda(), which failed for CPU params.
                        p.mul_(torch.exp(torch.randn(1).to(p.device) * decay))
                    elif decay_form == "honest":
                        pass  # already folded into ``grad`` above
                    else:
                        raise ValueError(
                            f"Invalid weight decay form: {decay_form}"
                        )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Keep the running maximum of the 2nd moment and normalize
                    # by it instead of the current estimate.
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    second_moment = max_exp_avg_sq
                else:
                    second_moment = exp_avg_sq
                denom = (second_moment.sqrt() / math.sqrt(bias_correction2)).add_(
                    eps
                )

                step_size = lr / bias_correction1
                upd = exp_avg / denom
                # Add elementwise Gaussian noise (std = noise_factor) to the
                # normalized update. A multiplicative log-normal variant,
                # upd *= exp(randn_like(upd) * noise_factor), was also tried.
                if noise_factor > 0:
                    upd += torch.randn_like(upd) * noise_factor
                p.add_(upd, alpha=-step_size)
    return loss