in benchmarks/experimental/experimental_async_approaches.py [0:0]
def step(self, closure=None):
    """Performs a single optimization step.

    For every parameter that has a gradient, updates exponential moving
    averages of the gradient and of the squared gradient, applies bias
    correction, and moves the parameter along the normalized first moment
    (an Adam-style update). A non-zero ``weight_decay`` is applied directly
    to the parameter values, scaled by ``lr``, rather than being added to
    the gradient.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is
        given.
    """
    loss = None
    if closure is not None:
        # The closure typically runs forward/backward, so gradients must be
        # enabled for its duration.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        # "amsgrad" is optional in a param group and defaults to off.
        amsgrad = group.get("amsgrad", False)

        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.data
            p_data = p.data
            state = self.state[p]

            # State initialization
            if not state:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p_data)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p_data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(p_data)
            else:
                # Re-align existing state with the parameter's current
                # dtype/device (Tensor.to(other) matches ``other``).
                state["exp_avg"] = state["exp_avg"].to(p_data)
                state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data)
                if amsgrad:
                    state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data)

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            if amsgrad:
                max_exp_avg_sq = state["max_exp_avg_sq"]
            beta1, beta2 = group["betas"]

            state["step"] += 1

            exp_avg_data = exp_avg.data
            exp_avg_sq_data = exp_avg_sq.data

            # Decay the first and second moment running average coefficient
            exp_avg_data.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq_data.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now.
                # The result is written in place into the state tensor so the
                # running maximum persists across steps.
                torch.max(max_exp_avg_sq, exp_avg_sq_data, out=max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                denom = max_exp_avg_sq.sqrt().add_(group["eps"])
            else:
                denom = exp_avg_sq_data.sqrt().add_(group["eps"])

            bias_correction1 = 1 - beta1 ** state["step"]
            bias_correction2 = 1 - beta2 ** state["step"]
            # Both bias corrections are folded into a single scalar step size.
            step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

            if group["weight_decay"] != 0:
                # Shrink the weights directly: p <- p - lr * weight_decay * p
                p_data.add_(p_data, alpha=-group["weight_decay"] * group["lr"])

            p_data.addcdiv_(exp_avg_data, denom, value=-step_size)

    return loss