in src/peft/tuners/lora/layer.py [0:0]
def olora_init(self, adapter_name):
    """Initialize an adapter's LoRA factors with OLoRA and rebase the base-layer weight.

    The base layer's weight is made dense (dequantized if it is a bitsandbytes
    parameter) and QR-decomposed in float32.
    - The first ``r`` rows of ``R`` become ``lora_A``'s weight.
    - The first ``r`` columns of ``Q`` become ``lora_B``'s weight.
    - ``scaling * lora_B.weight @ lora_A.weight`` is then subtracted from the dense weight.

    Presumably the subtraction makes base + scaled adapter reproduce the original
    weight at init. That depends on the LoRA forward, which is not shown here;
    confirm there.

    The updated weight is written back in the base layer's original storage format:
    - 4-bit / 8-bit bnb parameters are rebuilt with their original class and attributes;
    - otherwise the result is cast back to the original dtype.

    Args:
        adapter_name: key into ``self.scaling``, ``self.r``, ``self.lora_A`` and ``self.lora_B``.

    Raises:
        TypeError: if the base weight is not a bnb-quantized parameter and its dtype is
            not float32, float16 or bfloat16.
    """
    base_layer = self.get_base_layer()
    orig_weight = base_layer.weight
    quant_kind = get_bnb_param_type(orig_weight)
    orig_dtype = orig_weight.dtype

    # Obtain a dense tensor to decompose.
    if quant_kind:
        # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
        dense = dequantize_module_weight(base_layer)
    elif orig_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(f"Unsupported data type for the base layer. Got {orig_dtype}.")
    else:
        dense = orig_weight

    scaling = self.scaling[adapter_name]
    rank = self.r[adapter_name]

    # Decompose in float32. `.to` returns the tensor itself when it is already float32.
    # So for an unquantized fp32 base layer, the in-place edit below touches orig_weight directly.
    dense = dense.to(torch.float32)
    q_mat, r_mat = torch.linalg.qr(dense.data)

    # Truncate the factors to the adapter rank. Both come from the float32 QR.
    lora_a = self.lora_A[adapter_name]
    lora_a.weight.data = r_mat[:rank].contiguous()
    lora_b = self.lora_B[adapter_name]
    lora_b.weight.data = q_mat[:, :rank].contiguous()

    # Remove the scaled initial adapter product from the dense weight.
    dense.data -= scaling * lora_b.weight @ lora_a.weight

    # Write the result back using the base layer's original storage format.
    if quant_kind == "4bit":
        # Rebuild with the original bnb parameter class and settings, then move it to the
        # original device. NOTE(review): re-quantization presumably happens in bnb's `.to`.
        base_layer.weight = orig_weight.__class__(
            dense,
            quant_type=orig_weight.quant_type,
            quant_storage=orig_weight.quant_storage,
            compress_statistics=orig_weight.compress_statistics,
            module=orig_weight.module,
        ).to(orig_weight.device)
    elif quant_kind == "8bit":
        base_layer.weight = orig_weight.__class__(
            dense,
            requires_grad=orig_weight.requires_grad,
            has_fp16_weights=orig_weight.has_fp16_weights,
        ).to(orig_weight.device)
    else:
        base_layer.weight.data = dense.to(orig_dtype)