in fairscale/optim/adam.py [0:0]
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradient to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. Have to be of same type as gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
    loss = None
    if closure is not None:
        loss = closure()

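    # Walk each param group, bucketing the per-parameter tensors by device so
    # that a single fused multi-tensor kernel launch can update every
    # parameter on that device at once.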
    for i in range(len(self.param_groups)):
        group = self.param_groups[i]
        bias_correction = 1 if group["bias_correction"] else 0
        tensorlists: Dict[torch.device, List[List[torch.Tensor]]] = dict()

        for j in range(len(group["params"])):
            p = group["params"][j]
            # note: p.grad should never be set for correct operation of a
            # mixed precision optimizer, which sometimes sends None gradients
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError(
                    "FusedAdam does not support sparse gradients, please consider SparseAdam instead"
                )

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)

            exp_avg = state["exp_avg"]
            exp_avg_sq = state["exp_avg_sq"]
            beta1, beta2 = group["betas"]

            state["step"] += 1
            out_p = p.data if self.mixed_precision else torch.tensor([])
            param = self.fp32_param_groups[i]["params"][j] if self.mixed_precision else p
            scale = 1.0

            if self.mixed_precision:
                pl = [param.data, exp_avg, exp_avg_sq, grad, out_p]
                if p.device not in tensorlists:
                    tensorlists[p.device] = [[], [], [], [], []]
                for tl, t in zip(tensorlists[p.device], pl):
                    tl.append(t)
            else:
                pl = [param.data, exp_avg, exp_avg_sq, grad]
                if p.device not in tensorlists:
                    tensorlists[p.device] = [[], [], [], []]
                for tl, t in zip(tensorlists[p.device], pl):
                    tl.append(t)

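        # A single found_inf flag, replicated to every device that owns
        # parameters; the fused kernel raises it when it encounters
        # non-finite values while applying the scaled update.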
        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=list(tensorlists.keys())[0])
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        for tensordevice, tensorlist in tensorlists.items():
            with torch.cuda.device(tensordevice):
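                # One fused multi-tensor Adam launch per device, processing
                # the bucketed tensors in chunks of 2048 * 32 elements.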
                fused_adam_cuda.adam(
                    2048 * 32,
                    self._overflow_buf,
                    tensorlist,
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    scale,
                    self._optim_scale,
                    per_device_found_inf.get(tensordevice),
                    state["step"],
                    self.eps_mode,
                    bias_correction,
                    group["weight_decay"],
                )

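        # If any device reported an overflow, the optimizer state scale is too
        # large: halve it and start the moments over.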
        if sum(v.item() for v in per_device_found_inf._per_device_tensors.values()):
            self._steps_since_optim_scale_change = 0
            self._optim_scale /= 2

            if self._optim_scale < 1.0:
                raise RuntimeError("Optimizer state scale < 1. This may mean that gradients are exploding")

            for group in self.param_groups:
                for p in group["params"]:
                    self.state[p]["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
                    self.state[p]["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)
        else:
            self._steps_since_optim_scale_change += 1

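        # After a run of overflow-free steps, try growing the state scale
        # again (capped at 2 ** 16).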
        if self._steps_since_optim_scale_change == self._optim_scale_update_freq:
            self._steps_since_optim_scale_change = 0
            if self._optim_scale < 2 ** 16:
                self._optim_scale *= 2

    return loss
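
From the call site, step() behaves like any other torch.optim optimizer step. A minimal usage sketch, assuming the enclosing class is exposed as fairscale.optim.Adam and accepts the usual Adam constructor keyword arguments (both are assumptions; only step() appears in this excerpt):

import torch
import torch.nn.functional as F

from fairscale.optim import Adam  # assumed import path for the enclosing class

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)  # constructor kwargs are an assumption

data = torch.randn(32, 1024, device="cuda")
target = torch.randn(32, 1024, device="cuda")

for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(data), target)
    loss.backward()
    # step() as defined above: bucket tensors by device, launch the fused
    # CUDA Adam kernel once per device, then adjust the optimizer-state scale.
    optimizer.step()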