optimum/neuron/models/inference/backend/modules/attention/gqa.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/attention/gqa.py
import enum
import logging
from typing import Optional, Tuple
import torch
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear
from neuronx_distributed.parallel_layers.mappings import gather_from_sequence_parallel_region
from neuronx_distributed.parallel_layers.pad import get_number_of_extra_heads
from neuronxcc.nki._private_kernels.qkv import rmsnorm_qkv_isa_kernel
from neuronxcc.nki.language import nc
from torch import nn
from torch.distributed import ProcessGroup
from torch.nn import functional as F
from torch_neuronx.xla_impl.ops import nki_jit # noqa: E402
from .utils import transpose_parallel_linear_layer
logger = logging.getLogger("Neuron")
_traced_qkv_kernel = nki_jit()(rmsnorm_qkv_isa_kernel)
class GQA(enum.Enum):
# This transforms a GQA attention mechanism into a traditional MHA mechanism
# by replicating the K/V heads to evenly match the corresponding Q heads.
    # This consumes more memory than the other sharding strategies,
    # but it works in all cases.
# Example:
# tp_degree = 32
# num_attention_heads: 56 -> 64
    # num_key_value_heads: 8 -> 64
    # | K1 K1 | K2 K2 | ... | K8 K8 | Pad Pad | ... | Pad Pad |
# | Q1 Q2 | Q3 Q4 | ... | Q55 Q56 | Pad Pad | ... | Pad Pad |
CONVERT_TO_MHA = "convert-to-mha"
# This transforms a GQA attention mechanism such that there is exactly
# one K/V head per tp_degree through replication e.g. 8 K/V heads with
# tp_degree=32 results in 32 K/V heads. This is more memory efficient but
# does not work for all configurations. Q heads are padded interleaved
# to retain correct alignment between Q and K/V heads.
# Example:
# tp_degree = 32
# num_attention_heads: 56 -> 64
    # num_key_value_heads: 8 -> 32
# | K1 | K1 | K1 | K1 | K2 | ...
# | Q1 Q2 | Q3 Q4 | Q5 Q6 | Q7 Pad | Q8 Q9 | ...
REPLICATE_TO_TP_DEGREE = "replicate-to-tp-degree"
def determine_sharding_strategy(
tp_degree: int, source_key_value_heads: int, desired_sharding_strategy: Optional[GQA] = None
) -> GQA:
sharding_strategy = desired_sharding_strategy if desired_sharding_strategy else GQA.REPLICATE_TO_TP_DEGREE
if sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE and (tp_degree % source_key_value_heads != 0):
sharding_strategy = GQA.CONVERT_TO_MHA
return sharding_strategy
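
# Example for determine_sharding_strategy (illustrative, not part of the upstream module):
# the default REPLICATE_TO_TP_DEGREE strategy is kept only when tp_degree is a multiple of
# the source K/V head count, otherwise it silently falls back to CONVERT_TO_MHA.
#
#   determine_sharding_strategy(tp_degree=32, source_key_value_heads=8)
#   # -> GQA.REPLICATE_TO_TP_DEGREE   (32 % 8 == 0)
#   determine_sharding_strategy(tp_degree=32, source_key_value_heads=12)
#   # -> GQA.CONVERT_TO_MHA           (32 % 12 != 0)
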
def get_shardable_head_counts(
tp_degree: int, num_attention_heads: int, num_key_value_heads: int, sharding_strategy: GQA
) -> Tuple[int, int]:
# Pad attention heads
updated_num_attention_heads = num_attention_heads + get_number_of_extra_heads(num_attention_heads, tp_degree)
# Replicate and pad K/V heads
updated_num_key_value_heads = num_key_value_heads
if num_attention_heads == num_key_value_heads: # MHA
updated_num_key_value_heads = updated_num_attention_heads
else: # GQA / MQA
if (num_key_value_heads < tp_degree) or (num_key_value_heads % tp_degree != 0):
if sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
assert tp_degree % num_key_value_heads == 0, (
"GQA.REPLICATE_TO_TP_DEGREE requires tp_degree to be divisible by num_key_value_heads"
)
updated_num_key_value_heads = tp_degree
elif sharding_strategy == GQA.CONVERT_TO_MHA:
updated_num_key_value_heads = updated_num_attention_heads
return updated_num_attention_heads, updated_num_key_value_heads
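
# Worked example for get_shardable_head_counts (illustrative, not part of the upstream
# module), using the numbers from the GQA enum comments above:
#
#   get_shardable_head_counts(32, 56, 8, GQA.REPLICATE_TO_TP_DEGREE)
#   # -> (64, 32): Q heads padded 56 -> 64 (next multiple of tp_degree),
#   #              K/V heads replicated 8 -> 32 (one per tensor-parallel rank)
#   get_shardable_head_counts(32, 56, 8, GQA.CONVERT_TO_MHA)
#   # -> (64, 64): K/V heads end up matching the padded Q head count
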
def maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: int, source_group_size: int):
if tensor is None:
return tensor
    # Why do we convert FP8 tensors to bfloat16?
    # Torch does not support torch.cat or torch.zeros (for large dimensions) for f8e4m3/f8e5m2,
    # so we cast to bfloat16, perform the padding, and then cast back to f8e4m3/f8e5m2.
recast_dtype = None
if tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
recast_dtype = tensor.dtype
tensor = tensor.to(torch.bfloat16)
shape = (
tensor.shape[:pad_dim] + (source_heads, tensor.shape[pad_dim] // source_heads) + tensor.shape[pad_dim + 1 :]
)
tensor = tensor.view(shape)
splits = torch.split(tensor, source_group_size, dim=pad_dim)
pad_size = list(splits[0].size())
pad_size[pad_dim] = (target_heads - source_heads) // (source_heads // source_group_size)
pads = [torch.zeros(pad_size, dtype=tensor.dtype)] * len(splits)
interleaved = [t for pair in zip(splits, pads) for t in pair]
tensor = torch.cat(interleaved, dim=pad_dim)
shape = tensor.shape[:pad_dim] + (tensor.shape[pad_dim] * tensor.shape[pad_dim + 1],) + tensor.shape[pad_dim + 2 :]
if recast_dtype is not None:
tensor = tensor.to(recast_dtype)
return tensor.view(shape)
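
# Shape sketch for maybe_pad_interleaved (illustrative, hypothetical sizes): growing a
# q_proj weight from 56 to 64 heads (head_dim=128, hidden_size=4096) with 7 Q heads per
# K/V head inserts one zero head after every group of 7, so Q heads stay aligned with
# their replicated K/V heads after sharding:
#
#   w = torch.zeros(56 * 128, 4096)
#   maybe_pad_interleaved(w, pad_dim=0, source_heads=56, target_heads=64, source_group_size=7).shape
#   # -> torch.Size([8192, 4096])   i.e. 64 * 128 rows laid out as | 7 heads | pad | 7 heads | pad | ...
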
def maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int):
if tensor is None:
return tensor
size_to_pad = int((tensor.shape[pad_dim] // source_heads) * target_heads - tensor.shape[pad_dim])
dims_after_pad_dim = len(tensor.size()) - pad_dim
pad_length = dims_after_pad_dim * 2
pad = (0,) * (pad_length - 1) + (size_to_pad,)
return F.pad(tensor, pad)
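
# Shape sketch for maybe_pad_tail (illustrative, hypothetical sizes): tail padding simply
# appends zero heads at the end of pad_dim, e.g. growing a 56-head projection to 64 heads
# with head_dim=128 and hidden_size=4096:
#
#   maybe_pad_tail(torch.zeros(56 * 128, 4096), source_heads=56, target_heads=64, pad_dim=0).shape
#   # -> torch.Size([8192, 4096])
#   maybe_pad_tail(torch.zeros(4096, 56 * 128), source_heads=56, target_heads=64, pad_dim=1).shape
#   # -> torch.Size([4096, 8192])   (the o_proj case, padded along the input-feature dim)
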
def replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0):
if tensor is None:
return tensor
shape = (
tensor.shape[:head_dim] + (source_heads, tensor.shape[head_dim] // source_heads) + tensor.shape[head_dim + 1 :]
)
tensor = tensor.view(shape)
tensor = torch.repeat_interleave(tensor, repeats=repeats, dim=head_dim)
shape = (
tensor.shape[:head_dim] + (tensor.shape[head_dim] * tensor.shape[head_dim + 1],) + tensor.shape[head_dim + 2 :]
)
return tensor.view(shape)
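
# Shape sketch for replicate_kv (illustrative, hypothetical sizes): repeating each of
# 8 K/V head blocks 4 times (e.g. tp_degree=32 with REPLICATE_TO_TP_DEGREE) produces the
# | K1 | K1 | K1 | K1 | K2 | ... layout shown in the GQA enum comments. Note that the
# head_dim argument here is the dimension index holding the heads, not the head size:
#
#   k = torch.zeros(8 * 128, 4096)   # 8 heads of size 128, hidden_size=4096
#   replicate_kv(k, source_heads=8, repeats=4, head_dim=0).shape
#   # -> torch.Size([4096, 4096])    # 32 * 128 rows
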
class BaseGroupQueryAttention(nn.Module):
def __init__(
self,
hidden_size: int,
head_dim: int,
num_attention_heads: int,
num_key_value_heads: int,
tp_degree: int = 1,
dtype: torch.dtype = torch.float32,
bias: bool = False,
desired_sharding_strategy: Optional[GQA] = None,
tensor_model_parallel_group: Optional[ProcessGroup] = None,
):
super().__init__()
if tensor_model_parallel_group is not None:
self.tensor_model_parallel_group = tensor_model_parallel_group
else:
self.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
if tensor_model_parallel_group:
if tp_degree == 1:
# update default value
tp_degree = tensor_model_parallel_group.size()
else:
assert tp_degree == self.tensor_model_parallel_group.size(), (
f"TP Degree {tp_degree} and tensor model parallel group size {self.tensor_model_parallel_group.size()} does not match"
)
self.hidden_size = hidden_size
self.tp_degree = tp_degree
self.head_dim = head_dim
self.dtype = dtype
self.bias = bias
self._src_num_attention_heads = num_attention_heads
self._src_num_key_value_heads = num_key_value_heads
self.sharding_strategy = determine_sharding_strategy(
tp_degree,
self._src_num_key_value_heads,
desired_sharding_strategy=desired_sharding_strategy,
)
self.num_attention_heads, self.num_key_value_heads = get_shardable_head_counts(
tp_degree,
self._src_num_attention_heads,
self._src_num_key_value_heads,
self.sharding_strategy,
)
def get_sharding_strategy(self) -> GQA:
return self.sharding_strategy
def get_num_attention_heads(self) -> int:
return self.num_attention_heads
def get_num_key_value_heads(self) -> int:
return self.num_key_value_heads
def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
raise NotImplementedError
def replace_prefixes(self, old_prefix, new_prefix, model_state_dict):
old_keys = []
new_keys = []
for key in model_state_dict.keys():
if old_prefix in key:
new_key = key.replace(old_prefix, new_prefix)
new_keys.append(new_key)
old_keys.append(key)
for key_index in range(len(old_keys)):
model_state_dict[new_keys[key_index]] = model_state_dict.pop(old_keys[key_index])
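
    # Example for BaseGroupQueryAttention.replace_prefixes (illustrative, hypothetical key
    # names): with old_prefix="model.layers.0.self_attn.q_proj" and
    # new_prefix="layers.0.self_attn.qkv_proj.q_proj", the checkpoint entry
    # "model.layers.0.self_attn.q_proj.weight" is re-keyed to
    # "layers.0.self_attn.qkv_proj.q_proj.weight", so the preshard hooks below can look
    # weights up under the module's own prefix.
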
class GroupQueryAttention_QKV(BaseGroupQueryAttention):
def __init__(
self,
hidden_size: int,
head_dim: int,
num_attention_heads: int,
num_key_value_heads: int,
tp_degree: int = 1,
dtype: torch.dtype = torch.float32,
bias: bool = False,
desired_sharding_strategy: Optional[GQA] = None,
gather_output: bool = True,
fused_qkv: bool = False,
clip_qkv: Optional[float] = None,
sequence_parallel_enabled: bool = False,
sequence_dimension: Optional[int] = None,
tensor_model_parallel_group: Optional[ProcessGroup] = None,
        rms_norm_eps: Optional[float] = None,
qkv_kernel_enabled: bool = False,
logical_nc_config: int = 1,
):
super().__init__(
hidden_size=hidden_size,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
tp_degree=tp_degree,
dtype=dtype,
bias=bias,
desired_sharding_strategy=desired_sharding_strategy,
tensor_model_parallel_group=tensor_model_parallel_group,
)
if fused_qkv and gather_output:
raise ValueError(
"Gathering states followed by fused qkv is not allowed as it has a different weight sharding scheme."
)
self.gather_output = gather_output
self.fused_qkv = fused_qkv
self.clip_qkv = clip_qkv
self.sequence_parallel_enabled = sequence_parallel_enabled
self.sequence_dimension = sequence_dimension
self.rms_norm_eps = rms_norm_eps
self.qkv_kernel_enabled = qkv_kernel_enabled
self.logical_nc_config = logical_nc_config
if self.tensor_model_parallel_group is not None:
if self.fused_qkv:
self.Wqkv = ColumnParallelLinear(
self.hidden_size,
(self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim,
bias=self.bias,
gather_output=self.gather_output,
dtype=dtype,
tensor_model_parallel_group=self.tensor_model_parallel_group,
)
if self.qkv_kernel_enabled:
# we need to transpose the weights on the CPU side to avoid
# needing to transpose on the device when using QKV kernel
self.Wqkv.weight = transpose_parallel_linear_layer(self.Wqkv.weight)
# Set heads info as weight parameter attributes to be used in weights sharding
setattr(self.Wqkv.weight, "fused_qkv", True)
setattr(self.Wqkv.weight, "num_attention_heads", self.num_attention_heads)
setattr(self.Wqkv.weight, "num_key_value_heads", self.num_key_value_heads)
setattr(self.Wqkv.weight, "head_dim", self.head_dim)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_attention_heads * self.head_dim,
bias=self.bias,
gather_output=self.gather_output,
dtype=dtype,
sequence_parallel_enabled=False,
tensor_model_parallel_group=self.tensor_model_parallel_group,
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=self.bias,
gather_output=self.gather_output,
dtype=dtype,
sequence_parallel_enabled=False,
tensor_model_parallel_group=self.tensor_model_parallel_group,
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=self.bias,
gather_output=self.gather_output,
dtype=dtype,
sequence_parallel_enabled=False,
tensor_model_parallel_group=self.tensor_model_parallel_group,
)
else:
if self.fused_qkv:
self.Wqkv = nn.Linear(
self.hidden_size,
(self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim,
bias=self.bias,
)
else:
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias)
def forward(self, hidden_states: torch.Tensor, rmsnorm=None):
if self.sequence_parallel_enabled and self.tensor_model_parallel_group is not None:
hidden_states = gather_from_sequence_parallel_region(
hidden_states,
self.sequence_dimension,
process_group=self.tensor_model_parallel_group,
)
if self.qkv_kernel_enabled:
assert self.fused_qkv, "QKV kernel only supported when fused_qkv is TRUE"
fused_rmsnorm = not self.sequence_parallel_enabled
return self._kernel_qkv_forward(hidden_states, fused_rmsnorm, rmsnorm)
else:
return self._native_qkv_forward(hidden_states)
def _native_qkv_forward(self, hidden_states: torch.Tensor):
if self.fused_qkv:
logger.debug("QKV: native compiler")
QKV = self.Wqkv(hidden_states)
return self._split_fused_qkv(QKV)
else:
Q = self.q_proj(hidden_states)
K = self.k_proj(hidden_states)
V = self.v_proj(hidden_states)
if self.clip_qkv is not None:
Q = Q.clamp(min=-self.clip_qkv, max=self.clip_qkv)
K = K.clamp(min=-self.clip_qkv, max=self.clip_qkv)
V = V.clamp(min=-self.clip_qkv, max=self.clip_qkv)
return Q, K, V
def _split_fused_qkv(self, QKV):
logger.debug(f"Fused QKV tensor has shape {QKV.shape}")
if self.clip_qkv is not None:
QKV = QKV.clamp(min=-self.clip_qkv, max=self.clip_qkv)
# shape of QKV is [batch, seqlen, fused_qkv_size]
# we split the fused QKV (dim=2) into Q, K, V
# for example:
        # for 405B with num_att_heads=128: TP=128 on LNC=2 (i.e. 64 shards) will split
        # QKV from [batch, seqlen, 512] into:
# Q [batch, seqlen, 256]
# K [batch, seqlen, 128]
# V [batch, seqlen, 128]
        # torch.split has an accuracy issue and leads to more reshapes in the HLO,
        # so we use torch.tensor_split here. NAPP-3145
q_end_index = self.num_attention_heads * self.head_dim // self.tp_degree
k_end_index = q_end_index + self.num_key_value_heads * self.head_dim // self.tp_degree
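        # Illustrative check of these indices for the 405B example above (64 shards,
        # head_dim=128, padded num_attention_heads=128, num_key_value_heads replicated
        # to 64): q_end_index = 128 * 128 // 64 = 256 and
        # k_end_index = 256 + 64 * 128 // 64 = 384, so dim 2 of size 512 is sliced
        # into Q[..., :256], K[..., 256:384] and V[..., 384:512].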
Q, K, V = torch.tensor_split(
QKV,
(
q_end_index,
k_end_index,
# rest of the QKV will go to V output
),
dim=2,
)
logger.debug(f"QKV shape before tensor_split: {QKV.shape}")
logger.debug(f"Q shape after tensor_split: {Q.shape}")
logger.debug(f"K shape after tensor_split: {K.shape}")
logger.debug(f"V shape after tensor_split: {V.shape}")
return Q, K, V
def _kernel_qkv_forward(self, hidden_states, fused_rmsnorm, rmsnorm):
logger.debug(f"QKV kernel: fused_rmsnorm={fused_rmsnorm} logical_nc_config={self.logical_nc_config}")
bs, seqlen, h = hidden_states.shape
h2, fused_qkv_size = self.Wqkv.weight.shape
logger.debug(f"fused QKV projection weight - shape: {self.Wqkv.weight.shape}, dtype: {self.Wqkv.weight.dtype}")
# shape checks
assert (
fused_qkv_size
== (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim // self.tp_degree
)
assert h == h2
QKV = torch.zeros(
bs,
seqlen,
fused_qkv_size,
dtype=hidden_states.dtype,
device=hidden_states.device,
)
grid = (nc(self.logical_nc_config),)
# the QKV kernel will automatically switch to the TKG QKV if seqlen==1
_traced_qkv_kernel[grid](
hidden_states,
self.Wqkv.weight,
# unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden]
            # it should be fine to pass this in as a dummy even if not using fused rmsnorm
rmsnorm.weight.unsqueeze(0) if rmsnorm else torch.ones((1, h), device=hidden_states.device),
QKV,
eps=self.rms_norm_eps,
kernel_name="QKV",
# Run RMSNorm inside the kernel if NOT using SP norm
fused_rmsnorm=(fused_rmsnorm and rmsnorm is not None),
)
return self._split_fused_qkv(QKV)
    def get_weight(
        self, prefix: str, layer: torch.nn.Module, layer_name: str, model_state_dict: dict
    ) -> torch.Tensor:
if hasattr(layer, "get_weight_from_state_dict"):
return layer.get_weight_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict)
return model_state_dict[f"{prefix}.{layer_name}.weight"]
    def get_bias(
        self, prefix: str, layer: torch.nn.Module, layer_name: str, model_state_dict: dict
    ) -> Optional[torch.Tensor]:
if hasattr(layer, "get_bias_from_state_dict"):
return layer.get_bias_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict)
return model_state_dict.get(f"{prefix}.{layer_name}.bias")
    def set_weight(
        self,
        tensor: torch.Tensor,
        prefix: str,
        layer: torch.nn.Module,
        layer_name: str,
        model_state_dict: dict,
    ) -> None:
        # TODO: support for setting the weight through a layer state dict helper is pending.
        model_state_dict[f"{prefix}.{layer_name}.weight"] = tensor
    def set_bias(
        self,
        tensor: torch.Tensor,
        prefix: str,
        layer: torch.nn.Module,
        layer_name: str,
        model_state_dict: dict,
    ) -> None:
if hasattr(layer, "set_bias_to_state_dict"):
layer.set_bias_to_state_dict(prefix=f"{prefix}.{layer_name}.", tensor=tensor, state_dict=model_state_dict)
else:
model_state_dict[f"{prefix}.{layer_name}.bias"] = tensor
def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
prefix_parts = prefix.split(".")
prefix = ".".join(prefix_parts[:-1])
hf_prefix = ".".join(prefix_parts[:-2])
if self.fused_qkv:
self.replace_prefixes(
old_prefix=f"{hf_prefix}.Wqkv",
new_prefix=f"{prefix}.Wqkv",
model_state_dict=model_state_dict,
)
qkv_weight = self.get_weight(
prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
)
q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split(
[
self._src_num_attention_heads * self.head_dim,
self._src_num_key_value_heads * self.head_dim,
self._src_num_key_value_heads * self.head_dim,
],
dim=0,
)
qkv_bias = self.get_bias(
prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
)
if qkv_bias is not None:
q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split(
[
self._src_num_attention_heads * self.head_dim,
self._src_num_key_value_heads * self.head_dim,
self._src_num_key_value_heads * self.head_dim,
],
dim=0,
)
else:
q_proj_bias, k_proj_bias, v_proj_bias = None, None, None
else:
self.replace_prefixes(
old_prefix=f"{hf_prefix}.q_proj",
new_prefix=f"{prefix}.q_proj",
model_state_dict=model_state_dict,
)
self.replace_prefixes(
old_prefix=f"{hf_prefix}.k_proj",
new_prefix=f"{prefix}.k_proj",
model_state_dict=model_state_dict,
)
self.replace_prefixes(
old_prefix=f"{hf_prefix}.v_proj",
new_prefix=f"{prefix}.v_proj",
model_state_dict=model_state_dict,
)
q_proj_weight = self.get_weight(
prefix=prefix,
layer=self.q_proj,
layer_name="q_proj",
model_state_dict=model_state_dict,
)
k_proj_weight = self.get_weight(
prefix=prefix,
layer=self.k_proj,
layer_name="k_proj",
model_state_dict=model_state_dict,
)
v_proj_weight = self.get_weight(
prefix=prefix,
layer=self.v_proj,
layer_name="v_proj",
model_state_dict=model_state_dict,
)
q_proj_bias = self.get_bias(
prefix=prefix,
layer=self.q_proj,
layer_name="q_proj",
model_state_dict=model_state_dict,
)
k_proj_bias = self.get_bias(
prefix=prefix,
layer=self.k_proj,
layer_name="k_proj",
model_state_dict=model_state_dict,
)
v_proj_bias = self.get_bias(
prefix=prefix,
layer=self.v_proj,
layer_name="v_proj",
model_state_dict=model_state_dict,
)
if self.num_key_value_heads != self._src_num_key_value_heads:
if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
repeats = self.tp_degree // self._src_num_key_value_heads
elif self.sharding_strategy == GQA.CONVERT_TO_MHA:
repeats = self._src_num_attention_heads // self._src_num_key_value_heads
k_proj_weight = replicate_kv(
k_proj_weight,
source_heads=self._src_num_key_value_heads,
repeats=repeats,
head_dim=0,
)
k_proj_bias = replicate_kv(
k_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
)
v_proj_weight = replicate_kv(
v_proj_weight,
source_heads=self._src_num_key_value_heads,
repeats=repeats,
head_dim=0,
)
v_proj_bias = replicate_kv(
v_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
)
if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
q_proj_weight = maybe_pad_interleaved(
q_proj_weight,
pad_dim=0,
source_heads=self._src_num_attention_heads,
target_heads=self.num_attention_heads,
source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
)
q_proj_bias = maybe_pad_interleaved(
q_proj_bias,
pad_dim=0,
source_heads=self._src_num_attention_heads,
target_heads=self.num_attention_heads,
source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
)
if self.sharding_strategy == GQA.CONVERT_TO_MHA:
q_proj_weight = maybe_pad_tail(
q_proj_weight,
source_heads=self._src_num_attention_heads,
target_heads=self.num_attention_heads,
pad_dim=0,
)
q_proj_bias = maybe_pad_tail(
q_proj_bias,
source_heads=self._src_num_attention_heads,
target_heads=self.num_attention_heads,
pad_dim=0,
)
k_proj_weight = maybe_pad_tail(
k_proj_weight,
source_heads=self._src_num_key_value_heads,
target_heads=self.num_key_value_heads,
pad_dim=0,
)
k_proj_bias = maybe_pad_tail(
k_proj_bias,
source_heads=self._src_num_key_value_heads,
target_heads=self.num_key_value_heads,
pad_dim=0,
)
v_proj_weight = maybe_pad_tail(
v_proj_weight,
source_heads=self._src_num_key_value_heads,
target_heads=self.num_key_value_heads,
pad_dim=0,
)
v_proj_bias = maybe_pad_tail(
v_proj_bias,
source_heads=self._src_num_key_value_heads,
target_heads=self.num_key_value_heads,
pad_dim=0,
)
if self.fused_qkv:
qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
self.set_weight(
tensor=qkv_weight,
prefix=prefix,
layer=self.Wqkv,
layer_name="Wqkv",
model_state_dict=model_state_dict,
)
if self.bias:
qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0)
self.set_bias(
tensor=qkv_bias,
prefix=prefix,
layer=self.Wqkv,
layer_name="Wqkv",
model_state_dict=model_state_dict,
)
else:
self.set_weight(
tensor=q_proj_weight,
prefix=prefix,
layer=self.q_proj,
layer_name="q_proj",
model_state_dict=model_state_dict,
)
self.set_weight(
tensor=k_proj_weight,
prefix=prefix,
layer=self.k_proj,
layer_name="k_proj",
model_state_dict=model_state_dict,
)
self.set_weight(
tensor=v_proj_weight,
prefix=prefix,
layer=self.v_proj,
layer_name="v_proj",
model_state_dict=model_state_dict,
)
if self.bias:
self.set_bias(
tensor=q_proj_bias,
prefix=prefix,
layer=self.q_proj,
layer_name="q_proj",
model_state_dict=model_state_dict,
)
self.set_bias(
tensor=k_proj_bias,
prefix=prefix,
layer=self.k_proj,
layer_name="k_proj",
model_state_dict=model_state_dict,
)
self.set_bias(
tensor=v_proj_bias,
prefix=prefix,
layer=self.v_proj,
layer_name="v_proj",
model_state_dict=model_state_dict,
)
return True
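
# Illustrative usage sketch for GroupQueryAttention_QKV (not part of the upstream module;
# values are hypothetical and an initialized NeuronX tensor-parallel group of size 32 is
# assumed):
#
#   qkv = GroupQueryAttention_QKV(
#       hidden_size=4096,
#       head_dim=128,
#       num_attention_heads=32,
#       num_key_value_heads=8,
#       tp_degree=32,
#       dtype=torch.bfloat16,
#       gather_output=False,
#   )
#   Q, K, V = qkv(hidden_states)   # hidden_states: [batch, seqlen, 4096]
#   # With gather_output=False each rank holds one Q head and one replicated K/V head,
#   # so Q, K and V all come back as [batch, seqlen, 128] per rank.
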
class GroupQueryAttention_O(BaseGroupQueryAttention):
def __init__(
self,
hidden_size: int,
head_dim: int,
num_attention_heads: int,
num_key_value_heads: int,
tp_degree: int = 1,
dtype: torch.dtype = torch.float32,
bias: bool = False,
desired_sharding_strategy: Optional[GQA] = None,
input_is_parallel: bool = False,
layer_name: str = "o_proj",
sequence_parallel_enabled: bool = False,
sequence_dimension: Optional[int] = None,
tensor_model_parallel_group: Optional[ProcessGroup] = None,
        rpl_reduce_dtype: Optional[torch.dtype] = None,
):
super().__init__(
hidden_size=hidden_size,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
tp_degree=tp_degree,
dtype=dtype,
bias=bias,
desired_sharding_strategy=desired_sharding_strategy,
tensor_model_parallel_group=tensor_model_parallel_group,
)
self.input_is_parallel = input_is_parallel
if self.tensor_model_parallel_group is not None:
self.o_proj = RowParallelLinear(
self.num_attention_heads * self.head_dim,
self.hidden_size,
bias=self.bias,
input_is_parallel=self.input_is_parallel,
dtype=self.dtype,
sequence_parallel_enabled=sequence_parallel_enabled,
sequence_dimension=sequence_dimension,
tensor_model_parallel_group=self.tensor_model_parallel_group,
reduce_dtype=rpl_reduce_dtype,
)
else:
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=self.bias)
        # Allows mapping "o_proj" to the corresponding name in model_state_dict.
        # For example, the CLIP vision model uses "out_proj".
self.layer_name = layer_name
def forward(self, attention_output: torch.Tensor):
return self.o_proj(attention_output)
def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
prefix_parts = prefix.split(".")
prefix = ".".join(prefix_parts[:-1])
hf_prefix = ".".join(prefix_parts[:-2])
self.replace_prefixes(
old_prefix=f"{hf_prefix}.{self.layer_name}",
new_prefix=f"{prefix}.o_proj",
model_state_dict=model_state_dict,
)
o_proj_weight = model_state_dict[f"{prefix}.o_proj.weight"]
if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
o_proj_weight = maybe_pad_interleaved(
o_proj_weight,
pad_dim=1,
source_heads=self._src_num_attention_heads,
target_heads=self.num_attention_heads,
source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
)
if self.sharding_strategy == GQA.CONVERT_TO_MHA:
o_proj_weight = maybe_pad_tail(
o_proj_weight,
source_heads=self._src_num_attention_heads,
target_heads=self.num_attention_heads,
pad_dim=1,
)
model_state_dict[f"{prefix}.o_proj.weight"] = o_proj_weight
return True
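
# Shape sketch for the o_proj padding above (illustrative, hypothetical sizes): a
# [hidden_size=4096, 56 * 128] o_proj weight becomes [4096, 64 * 128], with the zero
# columns interleaved (REPLICATE_TO_TP_DEGREE) or appended at the tail (CONVERT_TO_MHA)
# so they line up with the padded Q heads produced by GroupQueryAttention_QKV.preshard_hook
# and contribute nothing to the output projection.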