in src/peft/tuners/adaption_prompt/layer.py [0:0]
def forward(self, **kwargs):
    """
    Forward pass for the adapter which wraps the original attention module.

    The wrapped module is run first; then a gated attention term between the
    recomputed query states and the projected adaption prompt is added to its
    output.

    "Official" paper implementation:
    https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L141

    Args:
        kwargs: See the original LlamaAttention module. They are forwarded
            unchanged to the wrapped module and to the model-specific
            `compute_query_states` function.

    Returns:
        A tuple whose first element is the wrapped module's output plus the
        adapter output, followed by the wrapped module's remaining outputs
        unchanged.

    Raises:
        NotImplementedError: If `output_attention` is truthy in `kwargs`.
    """
    # NOTE(review): transformers attention modules name this argument
    # `output_attentions` (plural) -- confirm whether this key is intended.
    if kwargs.get("output_attention", False):
        raise NotImplementedError("output_attention is not currently supported.")

    output, *_ = self.model(**kwargs)
    bsz, q_len, embed_dim = output.shape[0], output.shape[1], output.shape[2]

    model_config = TRANSFORMERS_MODEL_CONFIG[self.model_type]
    k_proj_layer = model_config.k_proj_layer
    v_proj_layer = model_config.v_proj_layer
    o_proj_layer = model_config.o_proj_layer

    key_proj = getattr(self.model, k_proj_layer)
    if k_proj_layer == v_proj_layer:
        # Fused projection: one layer emits three embed_dim-wide chunks; the
        # last two are used as key and value. Both are full width here, so no
        # head repetition is needed.
        _, key, value = key_proj(self.adaption_prompt).split(embed_dim, dim=2)
        factor = 1
    else:
        key = key_proj(self.adaption_prompt)
        value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
        # Mistral has different input and output dimension for k_proj and v_proj
        # layers. Read the ratio from the configured key projection instead of
        # a hard-coded `k_proj` attribute so the layer names stay consistent.
        factor = key_proj.in_features // key_proj.out_features

    if hasattr(self.model, "num_heads"):
        # TODO: remove this clause after 2026-01-01
        num_heads = self.model.num_heads
    else:  # changed in https://github.com/huggingface/transformers/pull/35235
        num_heads = self.model.config.num_attention_heads
    num_key_value_heads = num_heads // factor
    head_dim = self.model.head_dim

    def _prompt_to_heads(projected):
        """Reshape a projected prompt to (bsz, num_heads, adapter_len, head_dim)."""
        # (bsz, num_key_value_heads, adapter_len, head_dim)
        per_kv_head = (
            projected.view(1, self.adapter_len, num_key_value_heads, head_dim)
            .repeat(bsz, 1, 1, 1)
            .transpose(1, 2)
        )
        # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
        # (bsz, num_heads, adapter_len, head_dim)
        return torch.repeat_interleave(per_kv_head, repeats=factor, dim=1)

    adapter_k = _prompt_to_heads(key)
    adapter_v = _prompt_to_heads(value)

    # Recompute query states.
    compute_query_states = model_config.compute_query_states
    # (bsz, num_heads, q_len, head_dim)
    query_states = compute_query_states(model=self.model, **kwargs)
    previous_dtype = query_states.dtype

    # (bsz, num_heads, q_len, adapter_len)
    scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(head_dim)
    # Upcast attention to fp32 for the softmax, then scale by the learned gate.
    # (bsz, num_heads, q_len, adapter_len)
    scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
    # (bsz, q_len, num_heads * head_dim)
    adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)

    # (bsz, q_len, hidden_size)
    if o_proj_layer is not None:
        adapter_output = getattr(self.model, o_proj_layer)(adapter_output)

    # Add adaption prompt output to original output.
    output = output + adapter_output

    # Restore original dtype.
    output = output.to(previous_dtype)
    return output, *_