def prepare_inputs_for_generation()

in src/peft/peft_model.py


    def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)

        # https://github.com/huggingface/transformers/pull/26681/ introduced a new cache format
        # for some architectures, which requires a special fix for prompt tuning etc.
        # TODO: starting with transformers 4.38, all architectures should support caching.
        uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0")
        uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
        transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
        if packaging.version.parse(transformers.__version__) > packaging.version.parse("4.43.3"):
            # https://github.com/huggingface/transformers/pull/31445
            transformers_new_cache_archs.append("bloom")

        uses_cache = uses_transformers_4_38 or (
            uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
        )
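        # `uses_cache` is True when this architecture/transformers version uses the new cache format, in which case
        # the prompt learning fix below (trimming `input_ids` against the cache length) is needed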

        # heuristic to determine whether we're in the 'prefill' stage (when the KV cache is filled with the values
        # from the initial input)
        is_prefill = (model_kwargs.get("cache_position") is not None) and (model_kwargs["cache_position"][0] == 0)

        if peft_config.peft_type == PeftType.POLY:
            model_kwargs["task_ids"] = task_ids
        if peft_config.is_prompt_learning:
            if uses_cache and (model_kwargs.get("past_key_values", None) is not None):
                # A change in the logic of `prepare_inputs_for_generation` makes the code below necessary.
                # In prompt learning methods, the past key values are longer than the `input_ids`.
                # As such, only consider the last input ids in the autoregressive generation phase.
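                # e.g. with 20 virtual tokens, the cache holds 20 more entries than there are text tokens, so once the
                # cache is at least as long as `input_ids`, only the most recent token has to be passed to the model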
                past_key_values = model_kwargs["past_key_values"]
                if isinstance(past_key_values, (tuple, list)):
                    seq_len = past_key_values[0][0].shape[-2]
                else:  # using transformers kv cache
                    seq_len = past_key_values.get_seq_length()
                if seq_len >= model_kwargs["input_ids"].shape[1]:
                    model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]

            if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
                if isinstance(attention_mask, dict):
                    # see: https://github.com/huggingface/transformers/pull/37866
                    # For now, just deal with the case of a single attention mask
                    if len(attention_mask) != 1:
                        raise ValueError(
                            f"Expected a single attention mask, got {len(attention_mask)} instead, please open an "
                            "issue (https://github.com/huggingface/peft/issues) and report the error."
                        )
                    attention_mask = list(attention_mask.values())[0]

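                # the virtual prompt tokens must always be attended to, so the attention mask is extended by
                # `num_virtual_tokens` positions filled with ones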
                size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
                if attention_mask.dim() == 4:
                    # Transform the 4d attention mask to 2d, leave it up to the model to deal with it instead of trying
                    # to create a 4d attention mask here.
                    # from [batch_size, heads, input_ids_length, total_sequence_length]
                    # to   [batch_size, total_sequence_length]
                    bs = attention_mask.shape[0]
                    total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
                    attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)

                    if is_prefill and (peft_config.peft_type != PeftType.PREFIX_TUNING):
                        # if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
                        # (embeddings) are inserted, thus set cache_position to correspond to these tokens
                        cache_position_ = torch.arange(total_seq_len, device=model_kwargs["input_ids"].device)
                    else:
                        # prefix tuning acts directly on the cache, no need to update cache_position
                        cache_position_ = model_kwargs["cache_position"]

                    attention_mask_new = create_attention_mask(
                        self.get_base_model(),
                        model_input=None,
                        attention_mask=attention_mask_2d,
                        past_key_values=model_kwargs.get("past_key_values"),
                        cache_position=cache_position_,
                        batch_size=bs,
                        sequence_length=total_seq_len,
                    )
                    model_kwargs["attention_mask"] = attention_mask_new
                else:
                    # 2d attention mask
                    model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            if model_kwargs.get("position_ids", None) is not None:
                warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
                model_kwargs["position_ids"] = None

            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            # no past_key_values, or past_key_values is an empty cache
            requires_prompt_injection = (model_kwargs.get("past_key_values", None) is None) or (
                isinstance(model_kwargs["past_key_values"], transformers.Cache)
                and not model_kwargs["past_key_values"].get_seq_length()
            )

            if requires_prompt_injection and peft_config.peft_type == PeftType.PREFIX_TUNING:
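                # prefix tuning injects the learned prefix directly into the KV cache rather than prepending
                # embeddings to the input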
                # some archs require max_cache_len to re-initialize the cache
                max_cache_len = getattr(model_kwargs.get("past_key_values", None), "max_cache_len", None)
                new_past_key_values = self.get_prompt(
                    batch_size=model_kwargs["input_ids"].shape[0],
                    max_cache_len=max_cache_len,
                )
                model_kwargs["past_key_values"] = new_past_key_values
            elif requires_prompt_injection:
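                # other prompt learning methods (prompt tuning, p-tuning, etc.) prepend the learned prompt embeddings
                # to the input embeddings and hand the model `inputs_embeds` instead of `input_ids`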
                inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
                prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids)
                prompts = prompts.to(inputs_embeds.dtype)
                model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                model_kwargs["input_ids"] = None

        # if we're in the prefill stage
        if is_prefill and (peft_config.peft_type == PeftType.PREFIX_TUNING):
            # for prefix tuning, the past_key_values have been prefilled
            model_kwargs["cache_position"] += peft_config.num_virtual_tokens
        elif peft_config.peft_type != PeftType.PREFIX_TUNING:  # prefix tuning needs cache_position
            # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is passed in the forward
            # pass to keep track of the position ids of the cache. We have to pop that from `model_kwargs` as
            # `cache_position` is properly created by the model, using the passed `inputs_embeds`:
            # https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
            _ = model_kwargs.pop("cache_position", None)

        return model_kwargs
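
For context, this method is not called directly; `transformers`' `generate()` invokes it on every decoding step. Below
is a minimal sketch of how it ends up being exercised, assuming a prompt-tuned causal LM (the checkpoint name and
generation settings are illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from peft import PromptTuningConfig, TaskType, get_peft_model

    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
    model = get_peft_model(base_model, peft_config)

    inputs = tokenizer("The weather today is", return_tensors="pt")
    # generate() calls prepare_inputs_for_generation() on each step; on the prefill step the 20 virtual tokens are
    # prepended as inputs_embeds and the attention mask is extended accordingly
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(output[0], skip_special_tokens=True))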