in src/peft/tuners/bone/layer.py [0:0]
def get_delta_weight_bone(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
    """
    Compute the base weight with the given Bone adapter merged in (or, with ``re=True``, removed).

    Input column ``j`` of ``orig_weight`` receives row ``j % r`` of the adapter block, where
    ``r = bone_block[adapter].size(0)``. In other words, the transposed block is tiled across the input
    dimension, and a trailing partial tile is used when ``in_features`` is not a multiple of ``r``
    (including the case ``in_features < r``).

    Args:
        adapter (str):
            The name of the adapter for which the delta weight should be computed.
        orig_weight (torch.Tensor):
            2D weight whose last dimension is the input dimension. The adapter block is expected to have
            shape ``(r, orig_weight.size(0))``. This tensor is never modified.
        re (bool):
            If ``False`` (default), add the adapter block (merge). If ``True``, subtract it (unmerge).

    Returns:
        torch.Tensor: A new tensor with the same shape as ``orig_weight``. Its dtype follows torch type
        promotion between ``orig_weight`` and the adapter block. When the CPU fp32 path below is taken, it
        is cast back to the adapter's dtype.
    """
    device = self.bone_block[adapter].device
    dtype = self.bone_block[adapter].dtype
    # In case users wants to merge the adapter weights that are in
    # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
    # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
    cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

    weight_bone = self.bone_block[adapter]
    if cast_to_fp32:
        weight_bone = weight_bone.float()

    out_features = orig_weight.size(0)
    in_features = orig_weight.size(-1)
    r = weight_bone.size(0)
    n_block = in_features // r
    n_block_size = n_block * r
    last_size = in_features - n_block_size

    # (r, out) -> (out, r): column k of `delta` is added to input column k of every r-wide block.
    delta = weight_bone.transpose(0, 1)
    if re:
        # a + (-b) is bitwise identical to a - b under IEEE-754.
        delta = -delta

    # Work on a fresh contiguous copy so the caller's tensor is never written to. The dtype matches
    # what `orig_weight + weight_bone` would produce under torch type promotion.
    output_tensor = orig_weight.to(
        dtype=torch.result_type(orig_weight, weight_bone),
        copy=True,
        memory_format=torch.contiguous_format,
    )
    if n_block:
        # The storage is contiguous, so `.view` is a true view (it raises rather than silently copying)
        # and the in-place add lands in `output_tensor`. delta.unsqueeze(1) is (out, 1, r).
        output_tensor[:, :n_block_size].view(out_features, n_block, r).add_(delta.unsqueeze(1))
    if last_size:
        # The trailing partial block uses only the first `last_size` adapter rows.
        output_tensor[:, n_block_size:].add_(delta[:, :last_size])

    if cast_to_fp32:
        output_tensor = output_tensor.to(dtype=dtype)

        # Cast back the weights. The values are unchanged by the fp32 round trip; this mirrors the
        # pattern used by the other tuners.
        self.bone_block[adapter].data = weight_bone.to(dtype)

    return output_tensor