in src/peft/tuners/boft/layer.py [0:0]
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Apply the convolution, using the BOFT-adapted kernel when adapters are live.

    Three paths are taken:

    * adapters disabled: unmerge first if the adapter was merged, then call the
      base layer unchanged;
    * adapter already merged: call the base layer unchanged;
    * otherwise: build a rotation matrix and a per-output-channel scale from
      every active adapter, apply them to the flattened base kernel, and run
      ``F.conv2d`` with the resulting weight.

    Args:
        x: Input tensor handed to the convolution.
        *args: Forwarded to the base layer on the non-adapted paths only.
        **kwargs: Forwarded to the base layer on the non-adapted paths only.

    Returns:
        The convolution output, cast back to the dtype ``x`` had on entry.

    Notes:
        The kernel is flattened with ``kernel_size[0]`` used for both spatial
        dimensions, so a square kernel is assumed. The flattening also assumes
        the weight holds ``out_features * in_features * k * k`` elements
        (i.e. no grouped convolution) -- TODO confirm against the code that
        creates the adapter parameters.
    """
    previous_dtype = x.dtype

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        kernel_edge = self.base_layer.kernel_size[0]
        # Size of one flattened filter: in_channels * k * k (square kernel assumed).
        flat_in_features = self.in_features * kernel_edge * kernel_edge

        # Neutral starting points: identity rotation and unit scale.
        boft_rotation = torch.eye(flat_in_features, device=x.device, dtype=x.dtype)
        boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=x.dtype)

        for active_adapter in self.active_adapters:
            if active_adapter not in self.boft_R:
                continue
            boft_R = self.boft_R[active_adapter]
            # Transposed so that it broadcasts against the (out_features, 1) scale.
            boft_s = self.boft_s[active_adapter].transpose(0, 1)
            dropout = self.boft_dropout[active_adapter]

            # Flatten the two leading dims so cayley_batch sees a batch of HxH blocks.
            N, D, H, _ = boft_R.shape
            boft_R = boft_R.view(N * D, H, H)
            # NOTE(review): cayley_batch is presumably a Cayley transform producing
            # orthogonal blocks -- it is defined outside this block; confirm there.
            orth_rotate_butterfly = self.cayley_batch(boft_R)
            orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H)
            orth_rotate_butterfly = dropout(orth_rotate_butterfly)

            if self.fbd_cuda_available:
                block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly)
            else:
                # Fallback: assemble one block-diagonal matrix from the D blocks.
                # squeeze(0) only removes the leading dim when N == 1.
                orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0)
                block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)

            # Match the input's dtype/device before the batched products.
            boft_P = self.boft_P.to(x)
            block_diagonal_butterfly = block_diagonal_butterfly.to(x)
            # Conjugate each factor: P @ B @ P^T.
            butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
            butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)

            # Left-multiply the factors in order to get a single matrix.
            butterfly_oft_mat = butterfly_oft_mat_batch[0]
            for i in range(1, butterfly_oft_mat_batch.shape[0]):
                butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat

            # Accumulate across adapters.
            boft_rotation = butterfly_oft_mat @ boft_rotation
            boft_scale = boft_s * boft_scale

        base_weight = self.base_layer.weight.data
        x = x.to(base_weight.dtype)

        # The rotation/scale were built in the input's dtype; torch.mm requires both
        # operands to share a dtype, so align them with the base weight. This is a
        # no-op when the input and weight dtypes already agree.
        boft_rotation = boft_rotation.to(base_weight.dtype)
        boft_scale = boft_scale.to(base_weight.dtype)

        # (out, in*k*k) -> (in*k*k, out) so the rotation acts on the input side.
        orig_weight = base_weight.view(self.out_features, flat_in_features)
        orig_weight = torch.transpose(orig_weight, 0, 1)
        rotated_weight = torch.mm(boft_rotation, orig_weight)
        rotated_weight = torch.transpose(rotated_weight, 0, 1)

        # One scale factor per output channel, then restore the 4-D kernel layout.
        scaled_rotated_weight = rotated_weight * boft_scale
        scaled_rotated_weight = scaled_rotated_weight.view(
            self.out_features, self.in_features, kernel_edge, kernel_edge
        )

        x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
        bias = self._cast_input_dtype(self.base_layer.bias, scaled_rotated_weight.dtype)
        # Forward the base layer's full geometry so this path agrees with the
        # merged/disabled paths for asymmetric stride/padding, string padding
        # and dilation.
        result = F.conv2d(
            input=x,
            weight=scaled_rotated_weight,
            bias=bias,
            stride=self.base_layer.stride,
            padding=self.base_layer.padding,
            dilation=self.base_layer.dilation,
        )

    result = result.to(previous_dtype)
    return result