in src/peft/tuners/lora/model.py [0:0]
@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
    # If adapter_names is passed as an argument, we inject it into the forward arguments.
    adapter_names = kwargs.pop("adapter_names", None)
    if adapter_names is None:
        # nothing to do
        yield
        return

    if self.training:
        raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
    # Check that users only passed actually existing adapters.
    # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
    # to check that there is at least one layer with the given name, or else something like a typo could easily slip
    # through unnoticed.
    expected_adapters = set()
    for layer in self.modules():
        if isinstance(layer, LoraLayer):
            expected_adapters |= layer.lora_A.keys()
            expected_adapters |= layer.lora_embedding_A.keys()
    unique_adapters = {name for name in adapter_names if name != "__base__"}
    unexpected_adapters = unique_adapters - expected_adapters
    if unexpected_adapters:
        raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")

    # deal with beam search
    num_beams = kwargs.get("num_beams", None)
    uses_beam_search = isinstance(num_beams, int) and (num_beams > 1)
    original_adapter_names = adapter_names[:]
    if uses_beam_search:
        if not isinstance(adapter_names, (list, tuple)):
            raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.")
        # When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and
        # then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the
        # encoder part. Further below, the original argument is thus restored for the encoder.
        adapter_names = sum(([n] * num_beams for n in adapter_names), [])
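        # For example, with num_beams=2, adapter_names=["a", "b"] is expanded to ["a", "a", "b", "b"],
        # matching the beam-expanded inputs.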

    hook_handles = []
    for module in self.modules():
        if isinstance(module, (LoraLayer, AuxiliaryTrainingWrapper)):
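            # The pre-forward hook injects `adapter_names` into this module's forward kwargs, so each wrapped
            # layer can route every sample in the batch to the adapter (or the base weights) requested for it.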
            pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
            handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
            hook_handles.append(handle)

    if uses_beam_search and hasattr(self.model, "get_encoder"):
        # For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
        # the extended adapter_names. This is because the encoder still uses the original, non-extended samples.
        for module in self.model.get_encoder().modules():
            if isinstance(module, (LoraLayer, AuxiliaryTrainingWrapper)):
                # Add another hook to overwrite the kwargs with the original adapter names -- this is easier than
                # trying to exclude the encoder.
                pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names)
                handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                hook_handles.append(handle)

    yield

    for handle in hook_handles:
        handle.remove()
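
Illustrative usage (a minimal sketch, not part of this file): the hooks registered above are what let `adapter_names` flow from a `generate` or `forward` call down to the individual LoRA layers for mixed-adapter batch inference. The model id, adapter paths, prompts, and adapter names below are hypothetical placeholders.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    # Hypothetical base model and adapter checkpoints.
    base = AutoModelForCausalLM.from_pretrained("base-model-id")
    tokenizer = AutoTokenizer.from_pretrained("base-model-id")
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token  # needed for batched padding
    peft_model = PeftModel.from_pretrained(base, "path/to/adapter_a", adapter_name="adapter_a")
    peft_model.load_adapter("path/to/adapter_b", adapter_name="adapter_b")

    inputs = tokenizer(["first prompt", "second prompt", "third prompt"], return_tensors="pt", padding=True)
    outputs = peft_model.generate(
        **inputs,
        # One adapter name per sample; "__base__" selects the unmodified base model for that sample.
        adapter_names=["adapter_a", "__base__", "adapter_b"],
        max_new_tokens=20,
    )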