compute_query_states()

Defined in optimum/habana/peft/layer.py


import inspect

import torch

# Rotary-embedding helpers defined in PEFT's adaption-prompt utilities (the module this function is copied from).
from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half


def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
    """
    Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
    The only differences are:
    -add reuse cache support.
    -add past key value list support
    """
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
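    # Project the hidden states to queries and reshape to (bsz, num_heads, q_len, head_dim)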
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

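    # Grouped-query attention: k_proj/v_proj may project to fewer heads than q_proj;
    # `factor` is the ratio of query heads to key/value heads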
    factor = model.k_proj.in_features // model.k_proj.out_features
    value_states = (
        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
    )

    seq_len = q_len

    if past_key_value is not None:
        if kwargs.get("reuse_cache", False):
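            # with reuse_cache, the HPU attention layers pass the cached key/value shapes
            # rather than the tensors, so the second-to-last entry is the cached sequence length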
            seq_len += past_key_value[0][-2]
        elif isinstance(past_key_value, (tuple, list)):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
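            # usable length of the cache for this layer, i.e. the tokens already processed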
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
        position_ids = new_cache_positions.unsqueeze(0)

    rotary_emb_kwargs = {"position_ids": position_ids}
    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
    if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens

    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)

    # For batched inference, unsqueeze cos/sin on the head dimension so they broadcast
    # against query_states; needed since https://github.com/huggingface/transformers/pull/29109
    if len(cos.shape) == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

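    # Apply the rotary embedding to the query states (rotate-half formulation)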
    return (query_states * cos) + (llama_rotate_half(query_states) * sin)
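
Usage sketch (illustrative only; the names `attn_layer`, `hidden`, and `pos_ids` are hypothetical, not from the repository). The function is intended to be called from an adaption-prompt attention wrapper, forwarding the keyword arguments the wrapped Llama-style attention module received; `model` must expose `q_proj`, `k_proj`, `v_proj`, `num_heads`, `head_dim`, `rotary_emb` and, when a cache is passed, `layer_idx`.

# Minimal call sketch, assuming `attn_layer` is such an attention module:
query_states = compute_query_states(
    model=attn_layer,
    hidden_states=hidden,      # (bsz, q_len, hidden_size)
    position_ids=pos_ids,      # (bsz, q_len) tensor, or None to have it computed
    past_key_value=None,       # or a tuple/list of cached (key, value) pairs, or a DynamicCache
    reuse_cache=False,         # set True when the HPU cache-reuse path passes cached shapes
)
# query_states: (bsz, num_heads, q_len, head_dim), with the rotary embedding already applied.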