optimum/habana/peft/layer.py
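# Imports assumed from the module context (this excerpt shows only the function); the llama
# rotary helpers are the ones defined alongside the upstream PEFT implementation.
import inspect

import torch
from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half

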
def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
"""
Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
The only differences are:
-add reuse cache support.
-add past key value list support
"""
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
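    # Project the query states to (bsz, num_heads, q_len, head_dim); `factor` is the
    # grouped-query-attention ratio (query heads per key/value head) inferred from the k_proj
    # shapes, so the value states carry num_heads // factor heads.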
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    factor = model.k_proj.in_features // model.k_proj.out_features
    value_states = (
        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
    )
    seq_len = q_len
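    # The cached sequence length is recovered differently per cache format: with `reuse_cache`
    # the first entry exposes the cached key/value shape (hence the [-2] index), older
    # transformers pass a tuple/list of tensors, and transformers >= 4.36 passes a Cache object.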
    if past_key_value is not None:
        if kwargs.get("reuse_cache", False):
            seq_len += past_key_value[0][-2]
        elif isinstance(past_key_value, (tuple, list)):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)
    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
        position_ids = new_cache_positions.unsqueeze(0)
    rotary_emb_kwargs = {"position_ids": position_ids}
    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
    if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens
    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)
    # For batched inference, unsqueeze on the correct dim
    # since https://github.com/huggingface/transformers/pull/29109
    if len(cos.shape) == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    return (query_states * cos) + (llama_rotate_half(query_states) * sin)
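

# Minimal usage sketch (illustrative only, not part of the repository file): it assumes an
# attention-like module exposing q_proj / v_proj / rotary_emb / num_heads / head_dim / layer_idx,
# which is exactly what compute_query_states reads above; the function name is hypothetical and
# the kwarg names mirror the code above.
def _example_pre_attn_forward(attn_module, hidden_states, position_ids=None, past_key_value=None, reuse_cache=False):
    # Returns the rotated query states for the adaption-prompt path; key/value states would be
    # handled separately by the surrounding attention implementation.
    return compute_query_states(
        model=attn_module,
        hidden_states=hidden_states,
        position_ids=position_ids,
        past_key_value=past_key_value,
        reuse_cache=reuse_cache,
    )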