src/peft/tuners/lora/bnb.py
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # no adapter to merge
        return

    for active_adapter in adapter_names:
        if active_adapter not in self.lora_A.keys():
            continue

        warnings.warn(
            "Merging a LoRA module into a 4-bit linear layer may give different generations due to rounding errors."
        )
        # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
        weight = self.get_base_layer().weight
        kwargs = weight.__dict__
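        # 4-bit weights cannot be updated in place: dequantize to a full-precision
        # tensor, apply the LoRA delta, then requantize further below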
        output = dequantize_bnb_weight(weight, state=weight.quant_state)

        if active_adapter not in self.lora_variant:  # vanilla LoRA
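            # the vanilla LoRA delta is scaling * (lora_B @ lora_A)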
            lora_data = self.get_delta_weight(active_adapter)
            w_data = output + lora_data
        else:
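            # LoRA variants (e.g. DoRA) implement their own merge logic on the dequantized weight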
            w_data = self.lora_variant[active_adapter].merge_safe(self, active_adapter, output)

        if safe_merge and not torch.isfinite(w_data).all():
            raise ValueError(
                f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
            )

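        # reuse the quantization settings of the old Params4bit (blocksize, quant_type,
        # compress_statistics, ...) and reset the flags that would mark it as already quantized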
if "bnb_quantized" in kwargs:
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
# torch.compile can introduce attributes preceded by '_', remove them
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
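        # quantization is triggered when the new Params4bit is moved from CPU back to the original device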
        self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)

        if self.lora_bias[active_adapter]:
            bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
            if safe_merge and not torch.isfinite(bias_data).all():
                raise ValueError(
                    f"NaNs detected in the merged bias. The adapter {active_adapter} seems to be broken"
                )
            self.get_base_layer().bias.data = bias_data

        self.merged_adapters.append(active_adapter)
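

# Usage sketch (not part of this file): merge() is usually reached via the PeftModel
# API rather than called directly. A minimal example, assuming `base_model` was loaded
# in 4-bit with bitsandbytes and that the target module names fit the model:
#
#     from peft import LoraConfig, get_peft_model
#
#     config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
#     peft_model = get_peft_model(base_model, config)
#     ...  # train the adapters
#     merged_model = peft_model.merge_and_unload(safe_merge=True)  # calls merge() on each layer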