in src/peft/optimizers/lorafa.py [0:0]
def step(self, closure: Callable = None):
    """
    Performs a single optimization step.

    Within each param group, parameters whose names contain "lora" are consumed in consecutive pairs: the first is
    treated as `lora_A` and the second as `lora_B`. `lora_A` is never modified here; `lora_B` receives an AdamW
    update driven by the LoRA-FA projected gradient (see the comments in the loop). Every other parameter that has
    a gradient receives a plain AdamW update. A LoRA pair whose `lora_B` has no gradient is skipped, just like a
    non-LoRA parameter without a gradient. An unpaired trailing "lora" parameter is ignored.

    The in-place parameter updates must run with autograd disabled (presumably via a `@torch.no_grad()` decorator
    on this method -- confirm); the closure itself is re-enabled for autograd so it can call `backward()`.

    Arguments:
        closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

    Returns:
        The value returned by `closure`, or `None` when no closure is given.
    """
    loss = None
    if closure is not None:
        # Like the built-in torch optimizers: the closure must be able to build a graph and backprop even though
        # the update below runs without autograd.
        with torch.enable_grad():
            loss = closure()

    def _adamw_update(param, grad, exp_avg, exp_avg_sq, step_count, group):
        # Shared AdamW update, used for lora_B (with the projected gradient) and for ordinary parameters.
        beta1, beta2 = group["betas"]
        # Decay the first and second moment running averages in place.
        exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])
        step_size = group["lr"]
        if group["correct_bias"]:  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1**step_count
            bias_correction2 = 1.0 - beta2**step_count
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        param.addcdiv_(exp_avg, denom, value=-step_size)
        # Decoupled weight decay: adding the squared weights to the loss would interact with the Adam moments, so
        # the weights are decayed directly instead (after the Adam step, as in the original "fixed version").
        if group["weight_decay"] > 0.0:
            param.add_(param, alpha=(-group["lr"] * group["weight_decay"]))

    for group in self.param_groups:
        scaling_factor = group["scaling_factor"]
        # Holds a lora_A parameter until the next "lora" parameter (its lora_B partner) is reached.
        pending_lora = []
        for p, n in zip(group["params"], group["names"]):
            if "lora" not in n:
                # Original AdamW for every non-LoRA parameter that has a gradient.
                if p.grad is None:
                    continue
                state = self.state[n]
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                _adamw_update(p, p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group)
                continue

            # LoRA parameters are collected even without a gradient, since lora_A is expected to be frozen.
            pending_lora.append(p)
            if len(pending_lora) < 2:
                continue
            A, B = pending_lora
            pending_lora = []
            if B.grad is None:
                # Nothing to project (e.g. the adapter did not take part in the forward pass).
                continue

            # The pair's state is keyed by the shared name prefix ending in "lora".
            state = self.state[n[: n.find("lora")] + "lora"]
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg_B"] = torch.zeros_like(B)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq_B"] = torch.zeros_like(B)

            # LoRA-FA: optimize the gradient of B as
            #   g^B = \left(\frac{r}{\alpha}\right)^2 (A^\top A)^{-1} g_{\text{LoRA-FA}}^B
            # to minimize \|\hat{g}_\text{LoRA-FA} - g\|_F^2. With A laid out as stored here, this is computed as
            # (1 / scaling_factor**2) * B.grad @ (A @ A.T)^{-1} (presumably scaling_factor == alpha / r -- confirm).
            delta = 1e-8  # regularizes the Gram matrix before the pseudo-inverse
            gram = A @ A.T
            gram_inv = torch.linalg.pinv(gram + delta * torch.eye(A.shape[0], device=A.device))
            # Adding the default-dtype identity promotes low-precision A to (at least) float32, while B.grad may be
            # fp16/bf16; matmul requires matching dtypes, so upcast the gradient (lossless, no-op if equal).
            grad_B_raw = B.grad.to(gram_inv.dtype)
            if is_bf16_available():
                with autocast(device_type=infer_device(), dtype=torch.bfloat16):
                    grad_B = (1 / scaling_factor**2) * (grad_B_raw @ gram_inv)
            else:
                grad_B = (1 / scaling_factor**2) * (grad_B_raw @ gram_inv)
            grad_B = grad_B.to(B.grad.dtype)

            state["step"] += 1
            _adamw_update(B, grad_B, state["exp_avg_B"], state["exp_avg_sq_B"], state["step"], group)
    return loss