in timm/optim/kron.py [0:0]
@torch.no_grad()
def step(self, closure=None):
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    total_momentum_size = 0
    total_momentum_mb = 0
    total_precond_size = 0
    total_precond_mb = 0
    for group in self.param_groups:
        mu_dtype = group.get("mu_dtype")
        precond_dtype = group.get("precond_dtype", torch.float32)
        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
        update_prob = group.get("preconditioner_update_probability", None)

        for p in group["params"]:
            if p.grad is None:
                continue

            grad = p.grad
            state = self.state[p]

            flattened = False
            if group['flatten']:
                grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
                flattened = True
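
            # Lazily initialize per-parameter state. "Q" holds one Kronecker preconditioner
            # factor per tensor dimension (triangular or diagonal, depending on
            # max_size_triangular / min_ndim_triangular / memory_save_mode), and the cached
            # einsum expressions describe how those factors contract with the gradient.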
            if len(state) == 0:
                state["step"] = 0
                state["update_counter"] = 0
                state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
                # init Q and einsum expressions on first step
                state["Q"], exprs = _init_Q_exprs(
                    grad,
                    group["precond_init_scale"],
                    group["max_size_triangular"],
                    group["min_ndim_triangular"],
                    group["memory_save_mode"],
                    dtype=precond_dtype,
                )
                self._param_exprs[p] = exprs

                # Accumulate sizes for log
                momentum_size = state["momentum_buffer"].numel()
                momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
                total_momentum_size += momentum_size
                total_momentum_mb += momentum_mb

                precond_size = sum(q.numel() for q in state["Q"])
                precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
                total_precond_size += precond_size
                total_precond_mb += precond_mb
            elif p not in self._param_exprs:
                # Init only the einsum expressions; this path runs after a state_dict load,
                # where Q itself is restored from the checkpoint.
                exprs = _init_Q_exprs(
                    grad,
                    group["precond_init_scale"],
                    group["max_size_triangular"],
                    group["min_ndim_triangular"],
                    group["memory_save_mode"],
                    dtype=precond_dtype,
                    init_q=False,
                )
                self._param_exprs[p] = exprs
            else:
                # Retrieve cached expressions
                exprs = self._param_exprs[p]
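
            # Deterministic, counter-based variant of a stochastic update schedule: instead
            # of flipping a coin each step, increment a counter and refresh the preconditioner
            # once the counter reaches 1 / update_prob, so all parameters in the group are
            # updated on the same steps.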
            # Update preconditioners all together deterministically
            if update_prob is None:
                update_prob = precond_update_prob_schedule
            if callable(update_prob):
                update_prob = update_prob(state["step"])
            state["update_counter"] += 1
            do_update = state["update_counter"] >= 1 / update_prob
            if do_update:
                state["update_counter"] = 0

            state["step"] += 1
            # Update momentum buffer
            beta = group["momentum"]
            bias_correction = 1 - beta ** state["step"]
            momentum_buffer = state["momentum_buffer"]
            momentum_buffer.mul_(group["momentum"]).add_(grad, alpha=1 - group["momentum"])

            # Restore momentum dtype
            if mu_dtype is not None:
                momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
            debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)
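
            # Scaling one Kronecker factor up and another down by the same amount leaves the
            # overall preconditioner unchanged, so occasionally balancing the factors keeps
            # their magnitudes in a similar numeric range for stability.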
            # Balance preconditioners roughly every 100 updates
            balance = self.rng.random() < 0.01 and do_update
            if grad.dim() > 1 and balance:
                self._balance_Q(state["Q"])
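
            # PSGD preconditioner fitting: draw a fresh Gaussian probe V and nudge the factors
            # of Q toward the whitening criterion, i.e. so that P = Q^T Q approximately
            # whitens the (momentum) gradient.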
            # Update preconditioner
            if do_update:
                exprA, exprGs, _ = exprs
                Q = state["Q"]
                if self.deterministic:
                    torch_rng = torch.Generator(device=debiased_momentum.device)
                    torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                else:
                    torch_rng = None

                V = torch.randn(
                    debiased_momentum.shape,
                    generator=torch_rng,
                    dtype=precond_dtype,
                    device=debiased_momentum.device,
                )
                G = debiased_momentum if momentum_into_precond_update else grad

                A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
                terms = self._q_terms(exprGs, A, conjB)
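
                # Per-factor update: 1-D factors are stored as diagonals, so the step is
                # elementwise and normalized by the infinity norm; >=2-D factors are kept
                # upper-triangular, so the step is projected with triu and normalized by a
                # cheap lower-bound estimate of the spectral norm of term1 + term2.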
                for q, (term1, term2) in zip(Q, terms):
                    tmp = term1 - term2
                    tmp *= group["precond_lr"]
                    if q.dim() < 2:
                        tmp *= q
                        tmp /= (term1 + term2).norm(float("inf")) + self._tiny
                    else:
                        tmp = torch.triu(tmp)
                        tmp /= _norm_lower_bound(term1 + term2) + self._tiny
                        tmp @= q
                    q.sub_(tmp)
            # Precondition gradients
            pre_grad = self._precond_grad(
                state["Q"],
                exprs,
                debiased_momentum,
            ).to(dtype=p.dtype)

            # RMS of pre_grad should be 1.0, so let's cap at 1.1
            pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))

            if flattened:
                pre_grad = pre_grad.view(p.shape)
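
            # Weight decay variants: stochastic decay samples the strength uniformly from
            # [0, 2 * weight_decay] (same mean, adds noise); decoupled decay shrinks the
            # weights directly (AdamW-style) instead of adding to the update; corrected decay
            # scales by lr**2 / default lr rather than lr, so the effective decay tracks the
            # learning-rate schedule without retuning weight_decay.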
            # Apply weight decay
            weight_decay = group["weight_decay"]
            if weight_decay != 0:
                if group["stochastic_weight_decay"]:
                    weight_decay = 2 * self.rng.random() * weight_decay

                if group["decoupled_decay"]:
                    if group['corrected_weight_decay']:
                        wd_scale = group["lr"] ** 2 / self.defaults['lr']
                    else:
                        wd_scale = group["lr"]
                    p.mul_(1. - wd_scale * weight_decay)
                else:
                    pre_grad.add_(p, alpha=weight_decay)

            # Update parameters
            p.add_(pre_grad, alpha=-group["lr"])
    if total_momentum_size > 0:
        _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
        _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")

    return loss
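

# Usage sketch (not part of kron.py): a minimal training step driving this optimizer
# through the standard PyTorch closure interface. It assumes the method above belongs to
# the Kron optimizer class in timm.optim.kron; the constructor arguments shown
# (lr, momentum, weight_decay) mirror the param-group keys read in step(), but check the
# class signature for the full set of hyperparameters.
import torch
import torch.nn as nn
from timm.optim.kron import Kron

model = nn.Linear(16, 4)
criterion = nn.MSELoss()
optimizer = Kron(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

x, y = torch.randn(8, 16), torch.randn(8, 4)

def closure():
    # step() re-enables grad around the closure, so backward works even though
    # step() itself runs under no_grad.
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)
print(loss.item())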