in src/peft/tuners/hra/layer.py [0:0]
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    """
    Merge the active adapter weights into the base weights.

    For every requested adapter that has an entry in `self.hra_u`, the base layer's weight is viewed as a
    `(out_features, in_features * k * k)` matrix, right-multiplied by the matrix returned by
    `self.get_delta_weight(adapter)`, viewed back as `(out_features, in_features, k, k)` and written to the
    base layer in its original dtype. Adapters are merged one after another in the order given.

    Args:
        safe_merge (`bool`, *optional*):
            If `True`, the merged weights are checked for non-finite values (NaN or infinity) before they are
            written to the base layer; a broken result is never assigned. Defaults to `False`.
        adapter_names (`list[str]`, *optional*):
            The list of adapter names that should be merged. If `None`, all active adapters will be merged.
            Defaults to `None`.

    Raises:
        ValueError: If `safe_merge` is `True` and the merged weights of an adapter contain non-finite values.
            Adapters merged earlier in the same call stay merged.
    """
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # no adapter to merge
        return

    # Resolve the wrapped layer once and use it for every attribute access below.
    base_layer = self.get_base_layer()

    for active_adapter in adapter_names:
        if active_adapter not in self.hra_u:
            continue

        orig_dtype = base_layer.weight.dtype
        # Both spatial dimensions are taken from `kernel_size[0]`, i.e. a square kernel is assumed.
        kernel = base_layer.kernel_size[0]

        flat_weight = base_layer.weight.data.view(
            self.out_features,
            self.in_features * kernel * kernel,
        )
        delta_weight = self.get_delta_weight(active_adapter)
        # `torch.mm` allocates a new tensor, so the base weight is left untouched until the assignment
        # below; no explicit copy is needed for the safe path.
        merged_weight = torch.mm(flat_weight.to(delta_weight.dtype), delta_weight).view(
            self.out_features,
            self.in_features,
            kernel,
            kernel,
        )

        if safe_merge and not torch.isfinite(merged_weight).all():
            raise ValueError(
                f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
            )

        base_layer.weight.data = merged_weight.to(orig_dtype)
        self.merged_adapters.append(active_adapter)