in src/peft/tuners/hra/layer.py [0:0]
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Run the conv base layer, composing active HRA deltas into the weight on the fly.

    If adapters are disabled (unmerging first when needed) or already merged,
    this defers entirely to the base layer. Otherwise the per-adapter delta
    matrices are chained and folded into the flattened conv kernel before a
    single F.conv2d call. The output is always cast back to the input dtype.
    """
    input_dtype = x.dtype

    if self.disable_adapters:
        # Adapters switched off: restore pristine weights if a merge happened,
        # then behave exactly like the wrapped layer.
        if self.merged:
            self.unmerge()
        output = self.base_layer(x, *args, **kwargs)
    elif self.merged:
        # Deltas already folded into the base weights; plain forward suffices.
        output = self.base_layer(x, *args, **kwargs)
    else:
        k = self.base_layer.kernel_size[0]
        # NOTE(review): flattened dim uses kernel_size[0] twice — square
        # kernels appear to be assumed by the HRA delta shapes; confirm
        # against get_delta_weight / the layer's __init__ checks.
        flat_dim = self.in_features * k * k

        # Chain every active adapter's delta into one composed transform.
        composed = torch.eye(flat_dim, device=x.device)
        for adapter_name in self.active_adapters:
            if adapter_name in self.hra_u.keys():
                delta = self.get_delta_weight(adapter_name)
                composed = torch.mm(composed.to(delta.dtype), delta)

        # Apply the composed transform to the flattened conv weight, then
        # restore the (out, in, k, k) kernel shape.
        flat_weight = self.base_layer.weight.data.view(self.out_features, flat_dim)
        flat_weight = self._cast_input_dtype(flat_weight, composed.dtype)
        bias = self._cast_input_dtype(self.base_layer.bias, composed.dtype)
        effective_weight = torch.mm(flat_weight, composed).view(
            self.out_features,
            self.in_features,
            k,
            k,
        )

        if self.cast_input_dtype_enabled:
            x = self._cast_input_dtype(x, effective_weight.dtype)
        else:
            x = x.to(self.get_base_layer().weight.data.dtype)

        output = F.conv2d(
            input=x,
            weight=effective_weight,
            bias=bias,
            padding=self.base_layer.padding[0],
            stride=self.base_layer.stride[0],
        )

    return output.to(input_dtype)