in higher/optim.py [0:0]
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
    zipped = zip(self.param_groups, grouped_grads)
    for group_idx, (group, grads) in enumerate(zipped):
        for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
            if g is None:
                continue

            if g.is_sparse:
                raise RuntimeError('ASGD does not support sparse gradients')

            state = self.state[group_idx][p_idx]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                state['eta'] = group['lr']
                state['mu'] = 1
                state['ax'] = _torch.zeros_like(p.data)

            state['step'] += 1

            if group['weight_decay'] != 0:
                g = _add(g, group['weight_decay'], p)

            # decay term
            p = p.mul(1 - group['lambd'] * state['eta'])

            # update parameter
            group['params'][p_idx] = _add(p, -state['eta'], g)

            # averaging
            if state['mu'] != 1:
                state['ax'] = _add(
                    state['ax'],
                    p.sub(state['ax']).mul(state['mu'])
                )
            else:
                state['ax'] = p

            # update eta and mu
            state['eta'] = (
                group['lr'] / _math.pow(
                    (1 + group['lambd'] * group['lr'] * state['step']),
                    group['alpha']
                )
            )
            state['mu'] = 1 / max(1, state['step'] - group['t0'])
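
# --- Usage sketch (not part of higher/optim.py) ----------------------------
# A minimal sketch of how the differentiable ASGD update above gets driven:
# higher wraps a stock torch.optim.ASGD instance, and each diffopt.step(loss)
# call routes through _update, so the inner-loop parameter updates stay on
# the autograd graph. The model, toy data, and hyperparameter values below
# are illustrative assumptions, not taken from the library; _add(x, a, y) is
# assumed to be the module's out-of-place helper computing x + a * y, which
# is why the updated tensor is written back into group['params'] instead of
# being mutated in place.

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(4, 1)
opt = torch.optim.ASGD(model.parameters(), lr=0.1, lambd=1e-4, alpha=0.75, t0=100)

x = torch.randn(8, 4)
y = torch.randn(8, 1)

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    # Each step runs the differentiable ASGD update shown above.
    for _ in range(3):
        inner_loss = F.mse_loss(fmodel(x), y)
        diffopt.step(inner_loss)

    # The outer loss can then be backpropagated through the chain of inner
    # ASGD updates (the usual meta-learning pattern).
    outer_loss = F.mse_loss(fmodel(x), y)
    outer_loss.backward()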