in src/peft/peft_model.py [0:0]
def generate(self, **kwargs):
    """Run ``self.base_model.generate`` with the active adapter applied.

    For the duration of the call, ``self.base_model.prepare_inputs_for_generation`` and
    ``self.base_model._prepare_encoder_decoder_kwargs_for_generation`` are replaced by this
    wrapper's own versions. They are reset to ``self.base_model_prepare_inputs_for_generation``
    and ``self.base_model_prepare_encoder_decoder_kwargs_for_generation`` on every exit path
    (normal completion or exception), via ``finally``.

    Args:
        **kwargs: Forwarded to ``self.base_model.generate``.
            - Non-prompt-learning adapters: keys listed in ``self.special_peft_forward_args``
              are consumed by ``self._enable_peft_forward_hooks`` and removed before forwarding.
            - Prompt-learning adapters: ``input_ids`` is required; a non-None ``position_ids``
              or ``token_type_ids`` is replaced by ``None`` with a warning.
            - Prompt tuning / p-tuning / multitask prompt tuning: ``encoder_outputs`` is dropped
              with a warning, ``task_ids`` is consumed here, and ``input_ids`` is replaced by
              ``inputs_embeds`` with the virtual-token prompt prepended.

    Returns:
        Whatever ``self.base_model.generate`` returns.

    Raises:
        ValueError: If a prompt-learning adapter is active and ``input_ids`` is missing.
        NotImplementedError: For prompt-learning types other than prefix tuning, prompt tuning,
            p-tuning and multitask prompt tuning.
    """
    peft_config = self.active_peft_config
    # Route the base model's generation hooks through this wrapper while generating.
    self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
    self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
        self._prepare_encoder_decoder_kwargs_for_generation
    )
    try:
        if not peft_config.is_prompt_learning:
            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                outputs = self.base_model.generate(**kwargs)
        else:
            if "input_ids" not in kwargs:
                raise ValueError("input_ids must be provided for Peft model generation")
            if kwargs.get("position_ids", None) is not None:
                warnings.warn(
                    "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                )
                kwargs["position_ids"] = None
            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            if peft_config.peft_type == PeftType.PREFIX_TUNING:
                outputs = self.base_model.generate(**kwargs)
            elif peft_config.peft_type in (
                PeftType.PROMPT_TUNING,
                PeftType.P_TUNING,
                PeftType.MULTITASK_PROMPT_TUNING,
            ):
                # `kwargs` is this call's own dict, so deleting a key here does not touch the
                # caller's objects. Drop `encoder_outputs` *before* deep-copying so the (possibly
                # large, possibly autograd-attached) tensors it holds are not copied for nothing.
                if "encoder_outputs" in kwargs:
                    del kwargs["encoder_outputs"]
                    warnings.warn(
                        "`encoder_outputs` should not be passed to `generate` when using prompt tuning. Ignoring it."
                    )
                kwargs = deepcopy(kwargs)

                input_ids = kwargs.pop("input_ids")
                inputs_embeds = self.word_embeddings(input_ids)
                batch_size = inputs_embeds.shape[0]
                prompts = self.get_prompt(batch_size=batch_size, task_ids=kwargs.pop("task_ids", None))
                prompts = prompts.to(inputs_embeds.dtype)
                # Prepend the virtual-token prompt along dim 1 (the token dimension).
                kwargs["inputs_embeds"] = torch.cat(
                    (prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1
                )

                attention_mask = kwargs.get("attention_mask")
                if attention_mask is not None:
                    # Extend the mask to cover the prepended virtual tokens. Build the prefix
                    # directly with the user mask's dtype/device so torch.cat does not promote an
                    # integer/bool mask to float, and no CPU->device copy is needed.
                    prefix_attention_mask = torch.ones(
                        batch_size,
                        peft_config.num_virtual_tokens,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
                outputs = self.base_model.generate(**kwargs)
            else:
                raise NotImplementedError
    finally:
        # Always restore the base model's own hooks. The previous try/except/else skipped this
        # whenever the try body returned early, leaving the base model permanently patched.
        self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model_prepare_encoder_decoder_kwargs_for_generation
        )
    return outputs