def forward()

in src/peft/tuners/hra/layer.py [0:0]


    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the wrapped Conv2d with the active HRA adapters folded into its weight.

        - Adapters disabled: any merged adapter weights are unmerged first, then the base layer
          is called as-is.
        - Adapters merged: the base layer is called directly.
        - Otherwise: the flattened base weight is right-multiplied by the product of the active
          adapters' delta weights, and the convolution is recomputed with that weight using the
          base layer's own stride/padding/dilation/groups.

        Args:
            x: Input to the convolution.
            *args: Forwarded to the base layer on the disabled/merged paths only.
            **kwargs: Forwarded to the base layer on the disabled/merged paths only.

        Returns:
            The convolution output, cast back to ``x``'s original dtype.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            conv = self.base_layer
            orig_weight = conv.weight.data
            weight_shape = orig_weight.shape
            # Each output filter is flattened into one row of length fan_in
            # (in-channels-per-group * kernel_h * kernel_w). This is derived from the weight itself
            # rather than assuming a square kernel.
            fan_in = orig_weight[0].numel()

            # Chain the active adapters' transforms, starting from the identity. Each delta weight
            # must be (fan_in, fan_in) for the products and the final reshape to succeed.
            new_weight = torch.eye(fan_in, device=x.device)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.hra_u:
                    continue
                delta_weight = self.get_delta_weight(active_adapter)
                new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)

            orig_weight = orig_weight.view(weight_shape[0], fan_in)
            orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
            bias = self._cast_input_dtype(conv.bias, new_weight.dtype)

            # (out, fan_in) @ (fan_in, fan_in), then back to the conv weight layout.
            new_weight = torch.mm(orig_weight, new_weight).view(weight_shape)

            if self.cast_input_dtype_enabled:
                x = self._cast_input_dtype(x, new_weight.dtype)
            else:
                x = x.to(self.get_base_layer().weight.data.dtype)

            # Pass the base layer's full conv geometry (per-dimension stride/padding, string
            # padding such as "same", dilation, groups) so this path computes the same
            # convolution as calling the base layer directly.
            # NOTE(review): a non-"zeros" padding_mode on the base layer is not reproduced here.
            result = F.conv2d(
                input=x,
                weight=new_weight,
                bias=bias,
                stride=conv.stride,
                padding=conv.padding,
                dilation=conv.dilation,
                groups=conv.groups,
            )

        result = result.to(previous_dtype)
        return result