def forward()

in src/peft/tuners/lora/bnb.py [0:0]


        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            """Run the base layer and add the contribution of every active LoRA adapter.

            The arguments are first passed to ``self._check_forward_args``. The path is then
            chosen in this order of precedence:

            * ``self.disable_adapters``: call ``self.unmerge()`` if ``self.merged``, then return
              only ``self.base_layer(x, *args, **kwargs)``.
            * ``adapter_names`` given as a keyword argument: delegate to
              ``self._mixed_batch_forward``.
            * ``self.merged``: return only ``self.base_layer(x, *args, **kwargs)``.
            * otherwise: start from the base layer output. Then, for each name in
              ``self.active_adapters`` that has a ``lora_A`` entry, either add
              ``lora_B(lora_A(dropout(x))) * scaling`` (vanilla LoRA) or let the registered
              ``self.lora_variant`` entry compute the updated result.

            Args:
                x: Input passed to the base layer and to each adapter.
                *args: Extra positional arguments forwarded to ``self.base_layer``.
                **kwargs: Extra keyword arguments forwarded to ``self.base_layer``. The
                    optional ``adapter_names`` key is removed first and selects the
                    mixed-batch path.

            Returns:
                The output tensor. When autocast is not enabled, each adapter receives ``x``
                cast to that adapter's ``lora_A.weight.dtype``. Its contribution is cast back
                to the dtype of the base layer output.
            """
            self._check_forward_args(x, *args, **kwargs)
            adapter_names = kwargs.pop("adapter_names", None)

            if self.disable_adapters:
                if self.merged:
                    self.unmerge()
                return self.base_layer(x, *args, **kwargs)
            if adapter_names is not None:
                return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

            result = self.base_layer(x, *args, **kwargs)
            if self.merged:
                return result

            # Under autocast no explicit casts are done here; otherwise inputs and outputs
            # are cast around each adapter. Both values are loop-invariant: with conversion
            # enabled, ``result`` is always brought back to the base output dtype.
            requires_conversion = not torch.is_autocast_enabled()
            expected_dtype = result.dtype

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A:
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                # Cast a per-adapter copy instead of rebinding ``x``. Otherwise a later
                # adapter would see an input already rounded to an earlier adapter's dtype.
                adapter_x = self._cast_input_dtype(x, lora_A.weight.dtype) if requires_conversion else x

                if active_adapter not in self.lora_variant:  # vanilla LoRA
                    output = lora_B(lora_A(dropout(adapter_x))) * scaling
                    if requires_conversion:
                        output = output.to(expected_dtype)
                    result = result + output
                else:
                    result = self.lora_variant[active_adapter].forward(
                        self,
                        active_adapter=active_adapter,
                        x=adapter_x,
                        result=result,
                    )
                    if requires_conversion:
                        result = result.to(expected_dtype)

            return result