in src/peft/peft_model.py
def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):
peft_config = self.active_peft_config
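# Start from the generation inputs that the base transformers model would prepare, then adjust them
# below for the active PEFT method.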
model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
# https://github.com/huggingface/transformers/pull/26681/ introduced new cache format
# for some architectures which requires a special fix for prompt tuning etc.
# TODO: starting with transformers 4.38, all architectures should support caching.
uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0")
uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
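# architectures that had already switched to the new cache format in transformers 4.36 / 4.37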
transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
if packaging.version.parse(transformers.__version__) > packaging.version.parse("4.43.3"):
# https://github.com/huggingface/transformers/pull/31445
transformers_new_cache_archs.append("bloom")
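# True if the base model handles its KV cache in the new format, in which case the input_ids need the
# trimming below for prompt learning methods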
uses_cache = uses_transformers_4_38 or (
uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
)
# heuristic to determine if we're in 'prefill stage' (when the KV cache is filled with the values from the
# initial input)
is_prefill = (model_kwargs.get("cache_position") is not None) and (model_kwargs["cache_position"][0] == 0)
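# Poly (Polytropon) routes each sample through task-specific modules, so the task ids have to be
# forwarded to the model.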
if peft_config.peft_type == PeftType.POLY:
model_kwargs["task_ids"] = task_ids
if peft_config.is_prompt_learning:
if uses_cache and (model_kwargs.get("past_key_values", None) is not None):
# A change in the logic of `prepare_inputs_for_generation` makes the code below necessary.
# In prompt learning methods, the past key values are longer than the `input_ids` because they also cover
# the virtual tokens. As such, only consider the last input id in the autoregressive generation phase.
past_key_values = model_kwargs["past_key_values"]
if isinstance(past_key_values, (tuple, list)):
seq_len = past_key_values[0][0].shape[-2]
else: # using transformers kv cache
seq_len = past_key_values.get_seq_length()
if seq_len >= model_kwargs["input_ids"].shape[1]:
model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]
if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
if isinstance(attention_mask, dict):
# see: https://github.com/huggingface/transformers/pull/37866
# For now, just deal with the case of a single attention mask
if len(attention_mask) != 1:
raise ValueError(
f"Expected a single attention mask, got {len(attention_mask)} instead, please open an "
"issue (https://github.com/huggingface/peft/issues) and report the error."
)
attention_mask = list(attention_mask.values())[0]
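# all-ones attention mask for the virtual tokens, which should always be attended to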
size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
if attention_mask.dim() == 4:
# Transform the 4d attention mask to 2d, leave it up to the model to deal with it instead of trying
# to create a 4d attention mask here.
# from [batch_size, heads, input_ids_length, total_sequence_length]
# to [batch_size, total_sequence_length]
bs = attention_mask.shape[0]
total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
if is_prefill and (peft_config.peft_type != PeftType.PREFIX_TUNING):
# if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
# (embeddings) are inserted, thus set cache_position to correspond to these tokens
cache_position_ = torch.arange(total_seq_len, device=model_kwargs["input_ids"].device)
else:
# prefix tuning acts directly on the cache, so there is no need to update cache_position
cache_position_ = model_kwargs["cache_position"]
attention_mask_new = create_attention_mask(
self.get_base_model(),
model_input=None,
attention_mask=attention_mask_2d,
past_key_values=model_kwargs.get("past_key_values"),
cache_position=cache_position_,
batch_size=bs,
sequence_length=total_seq_len,
)
model_kwargs["attention_mask"] = attention_mask_new
else:
# 2d attention mask
model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if model_kwargs.get("position_ids", None) is not None:
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
model_kwargs["position_ids"] = None
if kwargs.get("token_type_ids", None) is not None:
warnings.warn(
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
)
kwargs["token_type_ids"] = None
# prompt injection is required if there are no past_key_values yet, or if past_key_values is an empty cache
requires_prompt_injection = (model_kwargs.get("past_key_values", None) is None) or (
isinstance(model_kwargs["past_key_values"], transformers.Cache)
and not model_kwargs["past_key_values"].get_seq_length()
)
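# Prefix tuning injects its virtual tokens directly into the KV cache, so the prompt is applied by
# creating pre-filled past_key_values.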
if requires_prompt_injection and peft_config.peft_type == PeftType.PREFIX_TUNING:
# some archs require max_cache_len to re-initialize the cache
max_cache_len = getattr(model_kwargs.get("past_key_values", None), "max_cache_len", None)
new_past_key_values = self.get_prompt(
batch_size=model_kwargs["input_ids"].shape[0],
max_cache_len=max_cache_len,
)
model_kwargs["past_key_values"] = new_past_key_values
elif requires_prompt_injection:
inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids)
prompts = prompts.to(inputs_embeds.dtype)
model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
model_kwargs["input_ids"] = None
# if we're in the prefill stage
if is_prefill and (peft_config.peft_type == PeftType.PREFIX_TUNING):
# for prefix tuning, the past_key_values have already been prefilled with the virtual tokens, so shift
# cache_position past them
model_kwargs["cache_position"] += peft_config.num_virtual_tokens
elif peft_config.peft_type != PeftType.PREFIX_TUNING: # prefix tuning needs cache_position
# For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is passed in the forward
# pass to keep track of the position ids of the cache. We have to pop that from `model_kwargs` as
# `cache_position` is properly created by the model, using the passed `inputs_embeds`:
# https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
_ = model_kwargs.pop("cache_position", None)
return model_kwargs
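
Illustrative usage (not part of peft_model.py): users never call prepare_inputs_for_generation directly; transformers' generate() invokes it at every decoding step once PEFT has hooked it in. A minimal sketch, assuming a prompt tuning adapter on a causal LM ("gpt2" and the hyperparameters are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PromptTuningConfig, TaskType, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("gpt2")
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base_model, PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8))

inputs = tokenizer("Hello", return_tensors="pt")
# generate() calls prepare_inputs_for_generation under the hood; for prompt tuning it prepends the
# virtual prompt embeddings and extends the attention mask as done in the function above.
output_ids = peft_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))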