optimum/habana/peft/layer.py (135 lines of code) (raw):
import inspect
import math
from typing import Any
import torch
import torch.nn.functional as F
from peft.tuners.adaption_prompt.config import TRANSFORMERS_MODEL_CONFIG
from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half
def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    Gaudi replacement for SVDLinear.forward, copied from
    https://github.com/huggingface/peft/blob/v0.9.0/src/peft/tuners/adalora/layer.py#L158
    The only differences are:
    - fix batch_gemm failure for BF16 case

    The base layer is always evaluated on `x`. When adapters are enabled and not merged, every active adapter
    that has an entry in `self.lora_A` adds `dropout(x) @ (A * E).T @ B.T`, multiplied by
    `scaling / (ranknum + 1e-5)`, to that output.

    Args:
        x: input tensor, forwarded to the base layer together with `*args` / `**kwargs`.

    Returns:
        The base layer output, plus the adapter contributions when they apply.
    """
    if self.disable_adapters:
        # Adapters are switched off: `unmerge()` is called first if the layer reports itself as merged.
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if self.merged:
        # Merged layers return the plain base layer output; no adapter term is added here.
        return self.base_layer(x, *args, **kwargs)

    output = self.base_layer(x, *args, **kwargs)
    for adapter_name in self.active_adapters:
        if adapter_name not in self.lora_A.keys():
            continue

        mat_a = self.lora_A[adapter_name]
        mat_b = self.lora_B[adapter_name]
        mat_e = self.lora_E[adapter_name]
        drop = self.lora_dropout[adapter_name]
        scale = self.scaling[adapter_name]
        # The small epsilon keeps the division below away from zero.
        rank = self.ranknum[adapter_name] + 1e-5

        # `x` is rebound on purpose: later adapters see the input already cast to this adapter's dtype.
        x = x.to(mat_a.dtype)
        # The scalar factor is evaluated on its own and applied with a single multiply
        # (this is the BF16 batch_gemm workaround).
        factor = scale / rank
        delta = drop(x) @ (mat_a * mat_e).T @ mat_b.T
        output += delta * factor
    return output
def GaudiPolyLayerLinearForward(
    self, x: torch.Tensor, *args: Any, task_ids: torch.Tensor = None, **kwargs: Any
) -> torch.Tensor:
    """
    Gaudi replacement for the Poly Linear.forward, copied from
    https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/poly/layer.py#L135
    The only differences are:
    - /r equal to *(1.0/r). /r makes batch_gemm BF16 failure

    Args:
        x: input tensor; it is passed to the base layer and used with `bmm`, so it has to be 3-D whenever
            an adapter is applied.
        task_ids: forwarded to the adapter's router as `task_ids`.

    Returns:
        The base layer output plus the routed low-rank updates, cast back to the dtype `x` had on entry.
    """
    input_dtype = x.dtype
    output = self.base_layer(x, *args, **kwargs)
    if self.disable_adapters:
        return output.to(input_dtype)

    for adapter_name in self.active_adapters:
        if adapter_name not in self.poly_lora_A.keys():
            continue

        rank = self.r[adapter_name]
        router = self.poly_router[adapter_name]
        skills_a = self.poly_lora_A[adapter_name]
        skills_b = self.poly_lora_B[adapter_name]

        # Combine the output of LoRAs
        # https://github.com/microsoft/mttl/blob/ce4ca51dbca73be656feb9b3e5233633e3c5dec7/mttl/models/poly.py#L293
        weights = router(task_ids=task_ids, input_ids=x)
        # The router output must be 3-D: (bs, n_splits, n_skills). Only the batch size is needed below.
        batch, _, _ = weights.size()

        # skills_a is n_splits, n_skills, D // n_splits, rank
        # mixing over the skills axis gives bs, n_splits, D // n_splits, rank
        mixed_a = torch.einsum("bqs,qsdr->bqdr", weights, skills_a)
        mixed_b = torch.einsum("bqs,qsrd->bqrd", weights, skills_b)

        # Collapse the split axis so each sample owns one (in_features, rank) and one (rank, out_features) matrix.
        mixed_a = mixed_a.reshape(batch, self.in_features, rank)
        mixed_b = mixed_b.transpose(1, 2).reshape(batch, rank, self.out_features)

        # `x` stays rebound so later adapters receive the already-cast input.
        x = x.to(mixed_a.dtype)
        low_rank = x.bmm(mixed_a).bmm(mixed_b)
        # Multiply by the reciprocal instead of dividing by `rank` (BF16 batch_gemm workaround).
        output += low_rank * (1.0 / rank)

    return output.to(input_dtype)
def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
    """
    Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
    The only differences are:
    - add reuse cache support.
    - add past key value list support
    - when `position_ids` has to be synthesised, the positions start at the number of cached tokens
      (0 without a cache), and that number is also available for tuple/list and reuse-cache inputs.

    Args:
        model: attention module. The code reads `q_proj`, `k_proj`, `v_proj`, `num_heads`, `head_dim` and
            `rotary_emb` from it, plus `layer_idx` when `past_key_value` is a cache object.
        **kwargs: keyword arguments of the attention call. Keys read here: `hidden_states` (required),
            `position_ids`, `past_key_value` and `reuse_cache`.

    Returns:
        The query projection of `hidden_states`, reshaped to (bsz, num_heads, q_len, head_dim), with the
        rotary position embedding applied.
    """
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

    # k_proj/v_proj may project to fewer heads than q_proj; `factor` is the resulting head ratio.
    factor = model.k_proj.in_features // model.k_proj.out_features
    value_states = (
        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
    )

    # Cached length read straight from the container. It stays None when there is no cache or when the
    # cache is an object that has to be queried through its own API.
    container_past_len = None
    seq_len = q_len
    if past_key_value is not None:
        if kwargs.get("reuse_cache", False):
            # NOTE(review): with reuse_cache the entry is indexed like a shape, so `past_key_value[0]`
            # presumably holds the cached key's shape rather than the tensor - confirm against the caller.
            container_past_len = past_key_value[0][-2]
            seq_len += container_past_len
        elif isinstance(past_key_value, (tuple, list)):
            # for transformers <= 4.35
            container_past_len = past_key_value[0].shape[-2]
            seq_len += container_past_len
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    rotary_forward_params = inspect.signature(model.rotary_emb.forward).parameters

    # For transformers > 4.37.2 `position_ids` became a required arguments in the rotary embedding's forward pass.
    if "position_ids" not in rotary_forward_params:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is not None:
            if container_past_len is not None:
                # Tuple/list/reuse-cache inputs have no `get_usable_length`; reuse the length found above.
                past_seen_tokens = container_past_len
            else:
                past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
        # New tokens occupy the positions right after the cached ones, i.e. 0..q_len-1 without a cache.
        # This keeps the positions inside the `seq_len` handed to the rotary embedding below.
        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
        position_ids = new_cache_positions.unsqueeze(0)

    rotary_emb_kwargs = {"position_ids": position_ids}
    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
    if "seq_len" in rotary_forward_params:
        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens

    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)

    # For batched inference unsqueeze it on the correct dim
    # since: https://github.com/huggingface/transformers/pull/29109
    if len(cos.shape) == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    return (query_states * cos) + (llama_rotate_half(query_states) * sin)
def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
    """
    Gaudi replacement for AdaptedAttention.forward, copied from
    https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/layer.py#L57
    The only differences are:
    - replace self.model() with self.model.pre_attn_forward()

    Runs the wrapped attention's `pre_attn_forward`, then adds a gated attention term computed between the
    recomputed query states and the projected adaption prompt.

    Returns:
        A tuple `(output, None, past_key_value)`; `past_key_value` is whatever `pre_attn_forward` returned.

    Raises:
        NotImplementedError: if `output_attention` is passed with a truthy value.
    """
    if kwargs.get("output_attention", False):
        raise NotImplementedError("output_attention is not currently supported.")

    attn = self.model
    output, _, past_key_value = attn.pre_attn_forward(*args, **kwargs)
    bsz, q_len, embed_dim = output.shape[0], output.shape[1], output.shape[2]

    layer_names = TRANSFORMERS_MODEL_CONFIG[self.model_type]
    k_name = layer_names.k_proj_layer
    v_name = layer_names.v_proj_layer
    o_name = layer_names.o_proj_layer

    # Mistral has different input and output dimension for k_proj and v_proj layers
    factor = attn.k_proj.in_features // attn.k_proj.out_features

    prompt = self.adaption_prompt
    if k_name == v_name:
        # One fused projection: split it into chunks of `embed_dim` and drop the first one.
        _, key, value = getattr(attn, k_name)(prompt).split(embed_dim, dim=2)
    else:
        key = getattr(attn, k_name)(prompt)
        value = getattr(attn, v_name)(prompt)

    kv_heads = attn.num_heads // factor

    def _expand_for_batch(projected):
        # (bsz, num_key_value_heads, adapter_len, head_dim)
        per_batch = projected.view(1, self.adapter_len, kv_heads, attn.head_dim).repeat(bsz, 1, 1, 1).transpose(1, 2)
        # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
        # (bsz, num_heads, adapter_len, head_dim)
        return torch.repeat_interleave(per_batch, repeats=factor, dim=1)

    adapter_k = _expand_for_batch(key)
    adapter_v = _expand_for_batch(value)

    # Recompute query states.
    # (bsz, num_heads, q_len, head_dim)
    query_states = compute_query_states(model=attn, **kwargs)
    target_dtype = query_states.dtype

    # (bsz, num_heads, q_len, adapter_len)
    keys_t = adapter_k.transpose(2, 3).to(target_dtype)
    scores = torch.matmul(query_states, keys_t) / math.sqrt(attn.head_dim)

    # Upcast attention to fp32
    # (bsz, num_heads, q_len, adapter_len)
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(target_dtype)
    scores = self.adaption_gate * probs

    # (bsz, q_len, num_heads * head_dim)
    adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)

    # (bsz, q_len, hidden_size)
    if o_name is not None:
        adapter_output = getattr(attn, o_name)(adapter_output)

    # Add adaption prompt output to original output, then restore original dtype.
    output = (output + adapter_output).to(target_dtype)
    return output, None, past_key_value
def GaudiAdaptedAttention_getattr(self, name: str):
    """Forward missing attributes to the wrapped module.

    Meant to be installed as a class's `__getattr__`. The lookup is first delegated to the parent class's
    `__getattr__`; only if that raises `AttributeError` is the attribute read from `self.model`.

    Args:
        name: attribute that regular lookup failed to find.

    Returns:
        The attribute from the parent lookup, or otherwise from `self.model`.

    Raises:
        AttributeError: if `model` itself cannot be resolved, or if `self.model` lacks `name`.
    """
    # Delegate relative to the class that actually owns this function, not `self.__class__`:
    # `super(self.__class__, self)` on a subclass instance resolves straight back to this function and
    # recurses forever. The most-base owner is kept in case several classes in the MRO install it.
    owner = type(self)
    for klass in type(self).__mro__:
        if klass.__dict__.get("__getattr__") is GaudiAdaptedAttention_getattr:
            owner = klass

    try:
        return super(owner, self).__getattr__(name)
    except AttributeError:
        if name == "model":
            # `self.model` is not available (e.g. the instance is not fully initialised). Reading it
            # below would re-enter this function endlessly, so surface the AttributeError instead.
            raise
        # This is necessary as e.g. causal models have various methods that we
        # don't want to re-implement here.
        return getattr(self.model, name)