in timm/optim/adafactor_bv.py
def step(self, closure=None):
    loss = None
    if closure is not None:
        # Re-evaluate the model under an enabled-grad context so the closure
        # can call backward() even though step() itself runs without grad.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        # Gather per-parameter tensors for this group so the update can be
        # dispatched in a single call (single-tensor or foreach path).
        params_with_grad = []
        grads = []
        exp_avg_sq_rs = []
        exp_avg_sq_cs = []
        exp_avg_sqs = []
        state_steps = []
        exp_avgs = []  # first-moment buffers, only used when momentum is enabled

        for p in group['params']:
            if p.grad is None:
                continue

            if p.grad.is_sparse:
                raise RuntimeError("Sparse gradients not supported")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            if len(state) == 0:
                # Lazy state initialization on the first step for this parameter.
                # NOTE: step counter kept on CPU; probably needs more work to make this capturable.
                state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                shape = p.grad.shape
                factored_dims = _factored_dims(
                    shape,
                    factored=True,
                    min_dim_size_to_factor=self.defaults['min_dim_size_to_factor'],
                )

                if factored_dims is not None:
                    # Factored second moment: keep row and column accumulators,
                    # each with one of the factored dims reduced to size 1.
                    dc, dr = factored_dims
                    row_shape = list(p.grad.shape)
                    row_shape[dr] = 1
                    col_shape = list(p.grad.shape)
                    col_shape[dc] = 1
                    state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                    state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                else:
                    # Small or 1-d parameters fall back to a full second-moment buffer.
                    state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                if self.defaults['momentum'] is not None:
                    state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])

            state_steps.append(state['step'])
            exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
            exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
            exp_avg_sqs.append(state.get('exp_avg_sq', None))
            exp_avgs.append(state.get('exp_avg', None))

        # Dispatch the actual Adafactor update for this group.
        if group['foreach']:
            func = _multi_tensor_adafactor
        else:
            func = _single_tensor_adafactor

        func(
            params=params_with_grad,
            grads=grads,
            exp_avg_sq_rs=exp_avg_sq_rs,
            exp_avg_sq_cs=exp_avg_sq_cs,
            exp_avg_sqs=exp_avg_sqs,
            exp_avgs=exp_avgs,
            state_steps=state_steps,
            beta2_decay=group['decay_rate'],
            beta2_cap=group['beta2_cap'],
            min_dim_size_to_factor=group['min_dim_size_to_factor'],
            eps=group['eps'],
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            momentum=group['momentum'],
            momentum_dtype=group['momentum_dtype'],
            clipping_threshold=group['clipping_threshold'],
            unscaled_wd=group['unscaled_wd'],
            caution=group['caution'],
            max_lr=self.defaults['lr'] if group['corrected_weight_decay'] else None,
        )

    return loss
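
A minimal usage sketch of the closure-based step above. The class name AdafactorBigVision, its import path, and the lr value are assumptions based on this file and common timm conventions, not a verified API reference; other hyper-parameters are left at their defaults.

import torch
import torch.nn as nn

from timm.optim.adafactor_bv import AdafactorBigVision  # assumed class/module name

model = nn.Linear(512, 256)
opt = AdafactorBigVision(model.parameters(), lr=1e-3)  # lr is illustrative

x = torch.randn(8, 512)
y = torch.randn(8, 256)

def closure():
    # step() wraps the closure in torch.enable_grad(), so backward() works here.
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = opt.step(closure)

# After the first step, the lazily initialized state for each parameter holds
# either factored row/column second-moment buffers ('exp_avg_sq_r'/'exp_avg_sq_c')
# or a full 'exp_avg_sq', plus 'step' and, if momentum is enabled, 'exp_avg'.
for p in model.parameters():
    print(tuple(p.shape), sorted(opt.state[p].keys()))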