def llama_compute_query_states()

in src/peft/tuners/adaption_prompt/utils.py


import inspect

import torch
import torch.nn as nn

# Note: llama_rotate_half and llama_apply_rotary_pos_emb, used below, are helper
# functions defined earlier in the same module.


def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
    """
    Compute query states for Llama models specifically. They need to be recomputed as the forward() method of the
    original LlamaModel in the transformers library does not return them. See the related discussion in the PR:
    https://github.com/huggingface/peft/pull/268
    """
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
    if hasattr(model, "num_heads"):
        # TODO: remove this clause after 2026-01-01
        num_heads = model.num_heads
    else:  # changed in https://github.com/huggingface/transformers/pull/35235
        num_heads = model.config.num_attention_heads
    query_states = model.q_proj(hidden_states).view(bsz, q_len, num_heads, model.head_dim).transpose(1, 2)

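    # With grouped-query attention, k_proj/v_proj project to fewer heads than q_proj;
    # this factor is the ratio of query heads to key/value heads.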
    factor = model.k_proj.in_features // model.k_proj.out_features
    value_states = model.v_proj(hidden_states).view(bsz, q_len, (num_heads // factor), model.head_dim).transpose(1, 2)

    seq_len = q_len

    if past_key_value is not None:
        if isinstance(past_key_value, tuple):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    # model.rotary_emb is deprecated and will be removed in transformers > 4.47.0. Instead, the position embeddings are
    # passed via the kwargs
    if "position_embeddings" in kwargs:
        cos, sin = kwargs["position_embeddings"]
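        # cos/sin arrive without a head dimension; add one so they broadcast over the
        # (bsz, num_heads, q_len, head_dim) query states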
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        return (query_states * cos) + (llama_rotate_half(query_states) * sin)

    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO: we assume position_ids is not None here; not certain that is safe, but the old code made the same assumption
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
        position_ids = new_cache_positions.unsqueeze(0)
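        # shape (1, q_len); broadcasts across the batch dimension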

    rotary_emb_kwargs = {"position_ids": position_ids}
    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
    if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens

    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)

    # For batched inference, unsqueeze cos/sin on the head dimension,
    # needed since: https://github.com/huggingface/transformers/pull/29109
    if len(cos.shape) == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    return (query_states * cos) + (llama_rotate_half(query_states) * sin)
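

Below is a minimal, illustrative usage sketch. It is not part of utils.py: the config
values, tensor shapes, and the dummy cos/sin rotary embeddings are assumptions made for
the example; in PEFT itself the function is invoked by the adaption-prompt attention
wrapper with the kwargs it received from the wrapped attention module.

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from peft.tuners.adaption_prompt.utils import llama_compute_query_states

config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
attn = LlamaAttention(config, layer_idx=0)

bsz, q_len = 2, 5
hidden_states = torch.randn(bsz, q_len, config.hidden_size)
position_ids = torch.arange(q_len).unsqueeze(0).expand(bsz, -1)

# Identity rotation, shaped (bsz, q_len, head_dim), to exercise the
# `position_embeddings` branch without depending on model.rotary_emb.
head_dim = config.hidden_size // config.num_attention_heads
cos = torch.ones(bsz, q_len, head_dim)
sin = torch.zeros(bsz, q_len, head_dim)

query_states = llama_compute_query_states(
    attn,
    hidden_states=hidden_states,
    position_ids=position_ids,
    past_key_value=None,
    position_embeddings=(cos, sin),
)
print(query_states.shape)  # torch.Size([2, 4, 5, 16]) -> (bsz, num_heads, q_len, head_dim)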