def _enable_peft_forward_hooks()

in src/peft/tuners/lora/model.py [0:0]


    @contextmanager  # from contextlib; the yield below marks the enter/exit boundary of the hook lifetime
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        # If adapter_names is passed as an argument, we inject it into the forward arguments.
        adapter_names = kwargs.pop("adapter_names", None)
        if adapter_names is None:
            # nothing to do: without adapter_names, every sample simply uses the currently active adapter(s)
            yield
            return

        if self.training:
            raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")

        # Check that users only passed actually existing adapters.
        # Note: We cannot perform this check on the layer level, as each individual layer may not have each adapter.
        # Still, we want to check that at least one layer has each given adapter name, or else typos could easily
        # slip through.
        expected_adapters = set()
        for layer in self.modules():
            if isinstance(layer, LoraLayer):
                expected_adapters |= layer.lora_A.keys()
                expected_adapters |= layer.lora_embedding_A.keys()
        unique_adapters = {name for name in adapter_names if name != "__base__"}
        unexpected_adapters = unique_adapters - expected_adapters
        if unexpected_adapters:
            raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")

        # deal with beam search
        num_beams = kwargs.get("num_beams", None)
        uses_beam_search = isinstance(num_beams, int) and (num_beams > 1)
        original_adapter_names = adapter_names[:]
        if uses_beam_search:
            if not isinstance(adapter_names, (list, tuple)):
                raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.")
            # When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and
            # then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the
            # encoder part. Further below, the original argument is thus restored for the encoder.
            adapter_names = sum(([n] * kwargs["num_beams"] for n in adapter_names), [])
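            # Illustrative expansion (names made up): adapter_names=["a", "b"] with num_beams=3 becomes
            # ["a", "a", "a", "b", "b", "b"], matching how generate() repeats each sample once per beam.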

        hook_handles = []
        for module in self.modules():
            if isinstance(module, (LoraLayer, AuxiliaryTrainingWrapper)):
                pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                hook_handles.append(handle)
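        # Each registered pre-forward hook injects the names into the layer's forward kwargs, roughly
        # kwargs["adapter_names"] = adapter_names; return args, kwargs (see _adapter_names_pre_forward_hook).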

        if uses_beam_search and hasattr(self.model, "get_encoder"):
            # For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
            # the extended adapter_names. This is because the encoder still uses the original, non-extended samples.
            for module in self.model.get_encoder().modules():
                if isinstance(module, (LoraLayer, AuxiliaryTrainingWrapper)):
                    # Add another hook to overwrite the kwargs with the original adapter names -- this is easier than
                    # trying to exclude the encoder. Pre-forward hooks run in registration order, so this later hook
                    # runs last and its adapter_names value wins for the encoder modules.
                    pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names)
                    handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                    hook_handles.append(handle)

        yield

        # Clean up: remove all hooks once the caller's forward/generate call has finished.
        for handle in hook_handles:
            handle.remove()
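
For context, a minimal usage sketch of how this context manager is typically triggered from user code via mixed-adapter-batch inference. The model name, adapter paths, and adapter names below are illustrative placeholders, not taken from this file; the only assumption is that two LoRA adapters compatible with the base model already exist.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
    base = AutoModelForCausalLM.from_pretrained("gpt2")

    # Load two adapters under distinct names (placeholder paths).
    model = PeftModel.from_pretrained(base, "path/to/adapter_a", adapter_name="adapter_a")
    model.load_adapter("path/to/adapter_b", adapter_name="adapter_b")
    model.eval()  # adapter_names cannot be passed while in training mode (see the check above)

    inputs = tokenizer(["sample 1", "sample 2", "sample 3"], return_tensors="pt", padding=True)
    # One adapter name per sample in the batch; "__base__" runs that sample without any LoRA weights.
    with torch.no_grad():
        outputs = model(**inputs, adapter_names=["adapter_a", "adapter_b", "__base__"])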