in src/peft/tuners/xlora/model.py [0:0]
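# NOTE: excerpt only. The surrounding file is assumed to provide
# `from contextlib import contextmanager`, `from functools import partial`,
# `import torch`, and the `LoraLayer` class from peft.tuners.lora.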
@contextmanager
def _enable_peft_forward_hooks(self, *generate_args, **generate_kwargs):
    def scalings_injection_hook(target, args, kwargs, scalings):
        # pre-forward hook that injects the computed X-LoRA scalings into a
        # LoraLayer call by adding them to its keyword arguments
        kwargs["scalings"] = scalings
        return args, kwargs
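
    # Filled in by pre_forward below (via nonlocal); the handles it holds are
    # removed after the yield at the end of this context manager.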
    handles_to_remove = None
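
    # Model-level pre-hook: runs the full two-pass X-LoRA scheme before each
    # real forward. First a scaling pass with adapters disabled, then the
    # injection hooks are re-registered with the classifier's computed scalings.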
    def pre_forward(module, *args, **kwargs):
        nonlocal handles_to_remove

        # =========================== Forward pass with "dummy" scalings ==================
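
        # Registered below with with_kwargs=True, so this hook is called as
        # pre_forward(module, args, kwargs): args[0] is the positional-argument
        # tuple and args[1] the keyword-argument dict of the intercepted call.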
        args_real = args[0]
        kwargs_real = args[1]
        kwargs_real.update(kwargs)

        dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real)
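
        # Inject placeholder scalings into every LoraLayer so the scaling pass
        # below can run before any real scalings exist.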
        hook_handles = []
        for layer in self.modules():
            if isinstance(layer, LoraLayer):
                injection = partial(scalings_injection_hook, scalings=dummy_scalings)
                handle = layer.register_forward_pre_hook(injection, with_kwargs=True)
                hook_handles.append(handle)
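
        # Scaling pass: run the model once with all adapters disabled and no
        # gradients, requesting the hidden states the classifier needs.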
        with torch.no_grad():
            self.lora_model.disable_adapter_layers()
            try:
                scaling_pass_kwargs = kwargs_real.copy()
                scaling_pass_kwargs["output_hidden_states"] = True
                scaling_pass_kwargs["return_dict"] = True
                try:
                    base_output = self.lora_model.model.forward(*args_real, **scaling_pass_kwargs)
                finally:
                    # Clean everything up
                    for handle in hook_handles:
                        handle.remove()
            finally:
                self.lora_model.enable_adapter_layers()
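
        # Compute the actual scalings from the hidden states of the scaling pass.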
        xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)

        # =========================== Real forward pass with calculated scalings ==================

        hook_handles = []
        for layer in self.modules():
            if isinstance(layer, LoraLayer):
                injection = partial(scalings_injection_hook, scalings=xlora_scalings)
                handle = layer.register_forward_pre_hook(injection, with_kwargs=True)
                hook_handles.append(handle)

        handles_to_remove = hook_handles
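
    # A single pre-hook on the wrapped base model triggers pre_forward, and with
    # it the whole scheme, once per forward call made inside this context.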
    if not self.disabled:
        forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True)

    # Run the forward pass: first the scaling pass in the hook, and then with the base model
    yield

    if not self.disabled:
        # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
        for handle in handles_to_remove:
            handle.remove()
        forward_handle.remove()
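
A minimal, self-contained sketch of the mechanism the method builds on: a forward pre-hook registered with `with_kwargs=True` (PyTorch >= 2.0) may return `(args, kwargs)` to rewrite the arguments a module is about to be called with, which is how `scalings` reaches each `LoraLayer` above. `EchoLayer` and its `scalings` parameter are illustrative stand-ins, not part of PEFT; `register_forward_pre_hook` and its `with_kwargs` flag are the real `torch.nn.Module` API.

import torch
from functools import partial


class EchoLayer(torch.nn.Module):
    # Hypothetical stand-in for a layer that expects a `scalings` kwarg,
    # the way LoraLayer does above.
    def forward(self, x, scalings=None):
        print("scalings received:", scalings)
        return x


def scalings_injection_hook(target, args, kwargs, scalings):
    # Same shape as the hook in the file: returning (args, kwargs)
    # replaces the arguments of the intercepted call.
    kwargs["scalings"] = scalings
    return args, kwargs


layer = EchoLayer()
handle = layer.register_forward_pre_hook(
    partial(scalings_injection_hook, scalings=torch.ones(2)), with_kwargs=True
)
layer(torch.zeros(2))  # prints: scalings received: tensor([1., 1.])
handle.remove()
layer(torch.zeros(2))  # prints: scalings received: None

The context manager above applies this same idea at two levels: one such hook per LoraLayer to deliver the scalings, plus one hook on the whole wrapped model (pre_forward) that computes those scalings just in time.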