in hype_kg/codes/optimizers/radam.py [0:0]
def step(self, manifold, closure=None):
    """Performs a single optimization step.

    Arguments
    ---------
    manifold : Manifold
        Default manifold, used for parameters that are not
        ManifoldParameter instances.
    closure : callable (optional)
        A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()
    # Remember the default manifold passed in; `manifold` is rebound per
    # parameter below for ManifoldParameter instances.
    default_manifold = manifold
    with torch.no_grad():
        for group in self.param_groups:
            if "step" not in group:
                group["step"] = 0
            betas = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            learning_rate = group["lr"]
            amsgrad = group["amsgrad"]
            for point in group["params"]:
                grad = point.grad
                if grad is None:
                    continue
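                # Use the parameter's own manifold (and curvature) when it is a
                # ManifoldParameter; otherwise fall back to the default manifold
                # passed to step().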
                if isinstance(point, ManifoldParameter):
                    manifold = point.manifold
                else:
                    manifold = default_manifold
                c = manifold.c
                if grad.is_sparse:
                    raise RuntimeError(
                        "Riemannian Adam does not support sparse gradients yet (not needed for these experiments)."
                    )
                state = self.state[point]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(point)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(point)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(point)
                # make local variables for easy access
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                # actual step: apply weight decay in ambient coordinates, then
                # convert the Euclidean gradient into a Riemannian gradient
                grad.add_(point, alpha=weight_decay)
                grad = manifold.egrad2rgrad(point, grad)
                exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                exp_avg_sq.mul_(betas[1]).add_(
                    manifold.inner(point, grad, keepdim=True), alpha=1 - betas[1]
                )
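                # Unlike Euclidean Adam, the second moment is built from the squared
                # Riemannian norm of the gradient (manifold.inner of grad with itself),
                # so it rescales the whole tangent vector rather than each coordinate.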
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(eps)
                else:
                    denom = exp_avg_sq.sqrt().add_(eps)
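                # Advance the step counter and apply the same bias correction as
                # Euclidean Adam; both corrections are folded into step_size below.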
                group["step"] += 1
                bias_correction1 = 1 - betas[0] ** group["step"]
                bias_correction2 = 1 - betas[1] ** group["step"]
                step_size = (
                    learning_rate * bias_correction2 ** 0.5 / bias_correction1
                )
                # Descent direction in the tangent space at the current point
                direction = exp_avg / denom
                # Retract: move along the exponential map and project back onto the manifold
                new_point = manifold.proj(manifold.expmap(-step_size * direction, point))
                # Parallel-transport the first-moment estimate to the tangent space at the new point
                exp_avg_new = manifold.ptransp(point, new_point, exp_avg)
                # use copy only for user facing point
                copy_or_set_(point, new_point)
                exp_avg.set_(exp_avg_new)
            if self._stabilize is not None and group["step"] % self._stabilize == 0:
                self.stabilize_group(group)
    return loss
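
# Illustration only, not part of the optimizer above: a minimal hand-rolled
# sketch of one Riemannian Adam-style step on the unit sphere, showing what
# egrad2rgrad (tangent projection), expmap (retraction) and ptransp (moving
# the momentum) correspond to. The sphere formulas and the toy loss are
# assumptions of this sketch; the manifold classes used in this repo expose a
# curvature attribute `c` and implement different formulas.
if __name__ == "__main__":
    import torch

    torch.manual_seed(0)
    p = torch.nn.functional.normalize(torch.randn(3), dim=0)  # a point on S^2
    target = torch.tensor([0.0, 0.0, 1.0])
    exp_avg = torch.zeros(3)  # first moment, lives in the tangent space at p
    lr, beta1 = 1e-1, 0.9

    for _ in range(5):
        egrad = p - target               # Euclidean grad of 0.5 * ||p - target||^2
        rgrad = egrad - (egrad @ p) * p  # egrad2rgrad: project onto the tangent space at p
        exp_avg = beta1 * exp_avg + (1 - beta1) * rgrad  # first-moment update (second moment omitted)
        u = -lr * exp_avg                # descent direction in the tangent space
        norm_u = u.norm().clamp_min(1e-12)
        new_p = torch.cos(norm_u) * p + torch.sin(norm_u) * u / norm_u  # expmap on the sphere
        exp_avg = exp_avg - (exp_avg @ new_p) * new_p  # crude stand-in for ptransp: re-project at new_p
        p = torch.nn.functional.normalize(new_p, dim=0)  # proj back onto the sphere
    print("final point:", p)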