def _enable_peft_forward_hooks()

in src/peft/tuners/xlora/model.py [0:0]


    @contextmanager
    def _enable_peft_forward_hooks(self, *generate_args, **generate_kwargs):
        def scalings_injection_hook(target, args, kwargs, scalings):
            # pre-forward hook that injects the computed X-LoRA scalings into each LoraLayer's forward kwargs
            kwargs["scalings"] = scalings
            return args, kwargs

        handles_to_remove = None

        def pre_forward(module, *args, **kwargs):
            nonlocal handles_to_remove

            # =========================== Forward pass with "dummy" scalings ==================

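            # The hook is registered with with_kwargs=True, so the wrapped model's positional
            # args arrive here as args[0] and its keyword args as args[1].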
            args_real = args[0]
            kwargs_real = args[1]
            kwargs_real.update(kwargs)

            dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real)

            hook_handles = []
            for module in self.modules():
                if isinstance(module, LoraLayer):
                    pre_forward = partial(scalings_injection_hook, scalings=dummy_scalings)
                    handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                    hook_handles.append(handle)

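            # Scaling pass: run the base model with the LoRA adapters disabled (and without
            # gradients) to obtain the hidden states the X-LoRA classifier needs.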
            with torch.no_grad():
                self.lora_model.disable_adapter_layers()

                try:
                    scaling_pass_kwargs = kwargs_real.copy()
                    scaling_pass_kwargs["output_hidden_states"] = True
                    scaling_pass_kwargs["return_dict"] = True
                    try:
                        base_output = self.lora_model.model.forward(*args_real, **scaling_pass_kwargs)
                    finally:
                        # Remove the dummy-scalings hooks before the real ones are registered
                        for handle in hook_handles:
                            handle.remove()
                finally:
                    self.lora_model.enable_adapter_layers()

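            # Turn the hidden states from the scaling pass into the actual X-LoRA scalings.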
            xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)

            # =========================== Real forward pass with calculated scalings ==================

            hook_handles = []
            for module in self.modules():
                if isinstance(module, LoraLayer):
                    pre_forward = partial(scalings_injection_hook, scalings=xlora_scalings)
                    handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                    hook_handles.append(handle)

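            # Leave these hooks in place for the real forward pass; they are removed after
            # the yield below.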
            handles_to_remove = hook_handles

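        # Register the scaling-pass hook on the wrapped model so it fires right before the
        # caller's real forward pass.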
        if not self.disabled:
            forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True)

        # Run the caller's forward pass: the hook above performs the scaling pass first, then
        # the base model runs with the computed scalings injected into every LoraLayer
        yield

        if not self.disabled:
            # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
            for handle in handles_to_remove:
                handle.remove()
            forward_handle.remove()
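
Usage note: this method is consumed as a context manager (the surrounding PEFT model wraps its forward/generate call in it), and the whole mechanism rests on PyTorch's register_forward_pre_hook(hook, with_kwargs=True), which lets a pre-hook return a replacement (args, kwargs) pair. The sketch below shows that kwargs-injection pattern in isolation; ScalingAwareLinear and injection_hook are hypothetical stand-ins (not part of PEFT) for a LoraLayer and scalings_injection_hook.

    from functools import partial

    import torch
    import torch.nn as nn


    class ScalingAwareLinear(nn.Module):
        # Hypothetical stand-in for a LoraLayer whose forward() accepts a "scalings" kwarg.
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)

        def forward(self, x, scalings=None):
            out = self.linear(x)
            if scalings is not None:
                out = out * scalings  # toy use of the injected tensor
            return out


    def injection_hook(target, args, kwargs, scalings):
        # Same shape as scalings_injection_hook above: "scalings" is bound via
        # functools.partial and injected into the module's forward kwargs.
        kwargs["scalings"] = scalings
        return args, kwargs


    layer = ScalingAwareLinear(4, 4)
    dummy_scalings = torch.full((1, 4), 0.5)
    handle = layer.register_forward_pre_hook(
        partial(injection_hook, scalings=dummy_scalings), with_kwargs=True
    )
    try:
        # forward() receives scalings=0.5 even though the caller never passes it.
        out = layer(torch.randn(1, 4))
    finally:
        handle.remove()  # mirror the handle cleanup done in the X-LoRA code

Collecting the hook handles and removing them in a finally block (or after the yield) is the same pattern the X-LoRA code uses to keep hooks from leaking into later forward passes.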