def forward()

in src/peft/tuners/boft/layer.py [0:0]


    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the convolution with the BOFT-rotated and scaled base weight.

        When adapters are disabled (unmerging first if they are currently merged)
        or already merged, the call is delegated to ``self.base_layer``. Otherwise
        the rotation matrices and scales of all active adapters are composed,
        applied to the base layer's flattened weight, and the convolution is run
        with the resulting kernel using the base layer's stride, padding and
        dilation.

        Args:
            x: Input tensor of the convolution.
            *args: Extra positional arguments; only forwarded to the base layer.
            **kwargs: Extra keyword arguments; only forwarded to the base layer.

        Returns:
            The layer output, cast back to the dtype ``x`` had on entry.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            base_layer = self.base_layer
            # Use height * width (not height ** 2) so that the flattened size
            # matches the real weight for non-square kernels as well.
            kernel_h, kernel_w = base_layer.kernel_size
            flat_dim = self.in_features * kernel_h * kernel_w

            # Start from the identity rotation / unit scale; with no matching
            # adapter the base weight is therefore used unchanged.
            boft_rotation = torch.eye(flat_dim, device=x.device, dtype=x.dtype)
            boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=x.dtype)

            for active_adapter in self.active_adapters:
                if active_adapter not in self.boft_R:
                    continue
                boft_R = self.boft_R[active_adapter]
                # Transposed so it multiplies with the (out_features, 1) scale
                # (presumably stored as (1, out_features) -- TODO confirm).
                boft_s = self.boft_s[active_adapter].transpose(0, 1)
                dropout = self.boft_dropout[active_adapter]

                # Flatten the leading two dims so cayley_batch sees a batch of
                # H x H blocks, then restore the (N, D, H, H) layout.
                N, D, H, _ = boft_R.shape
                boft_R = boft_R.view(N * D, H, H)
                orth_rotate_butterfly = self.cayley_batch(boft_R)
                orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H)
                orth_rotate_butterfly = dropout(orth_rotate_butterfly)
                if self.fbd_cuda_available:
                    block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly)
                else:
                    # NOTE(review): this fallback only yields 2-D blocks for
                    # torch.block_diag when N == 1 -- confirm that N > 1 never
                    # reaches this branch.
                    orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0)
                    block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                    block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)

                # Bring both operands to x's dtype and device before the bmm.
                boft_P = self.boft_P.to(x)
                block_diagonal_butterfly = block_diagonal_butterfly.to(x)
                # Per factor: P @ B @ P^T.
                butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                butterfly_oft_mat = butterfly_oft_mat_batch[0]

                # Chain the factors; later factors are applied on the left.
                for i in range(1, butterfly_oft_mat_batch.shape[0]):
                    butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat

                boft_rotation = butterfly_oft_mat @ boft_rotation
                boft_scale = boft_s * boft_scale

            orig_weight = base_layer.weight.data
            x = x.to(orig_weight.dtype)
            # The rotation and scale were built in the input dtype; align them
            # with the weight dtype because torch.mm rejects mixed dtypes.
            boft_rotation = boft_rotation.to(orig_weight.dtype)
            boft_scale = boft_scale.to(orig_weight.dtype)

            # reshape (instead of view) also accepts non-contiguous weights.
            orig_weight = orig_weight.reshape(self.out_features, flat_dim)
            orig_weight = torch.transpose(orig_weight, 0, 1)
            rotated_weight = torch.mm(boft_rotation, orig_weight)
            rotated_weight = torch.transpose(rotated_weight, 0, 1)

            # One scale factor per output channel (row).
            scaled_rotated_weight = rotated_weight * boft_scale

            scaled_rotated_weight = scaled_rotated_weight.view(
                self.out_features, self.in_features, kernel_h, kernel_w
            )
            x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
            bias = self._cast_input_dtype(base_layer.bias, scaled_rotated_weight.dtype)
            # Pass the base layer's full stride/padding/dilation (not just the
            # first component) so the output matches the base layer's geometry.
            result = F.conv2d(
                input=x,
                weight=scaled_rotated_weight,
                bias=bias,
                stride=base_layer.stride,
                padding=base_layer.padding,
                dilation=base_layer.dilation,
            )

        result = result.to(previous_dtype)
        return result