# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/attention/gqa.py
import enum
import logging
import re
from typing import Optional, Tuple

import torch
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear
from neuronx_distributed.parallel_layers.mappings import gather_from_sequence_parallel_region
from neuronx_distributed.parallel_layers.pad import get_number_of_extra_heads
from neuronxcc.nki._private_kernels.qkv import rmsnorm_qkv_isa_kernel
from neuronxcc.nki.language import nc
from torch import nn
from torch.distributed import ProcessGroup
from torch.nn import functional as F
from torch_neuronx.xla_impl.ops import nki_jit  # noqa: E402

from .utils import transpose_parallel_linear_layer


# Module logger; it is looked up under the shared "Neuron" name rather than this module's __name__.
logger = logging.getLogger("Neuron")

# NKI RMSNorm/QKV kernel wrapped once at import time. _kernel_qkv_forward launches it with
# `_traced_qkv_kernel[grid](...)`.
_traced_qkv_kernel = nki_jit()(rmsnorm_qkv_isa_kernel)


class GQA(enum.Enum):
    """Strategies for distributing grouped-query attention heads over tensor-parallel ranks.

    Under both strategies the query heads are padded (see ``get_shardable_head_counts``) so that
    they shard across ``tp_degree``. The strategies differ in how key/value heads are multiplied
    so that every rank holds the K/V heads its Q heads attend with.
    """

    # CONVERT_TO_MHA turns GQA into ordinary multi-head attention: each K/V head is replicated
    # once per query head it serves, then K/V are padded alongside Q. It needs more memory than
    # the other strategy but works for every configuration.
    #
    # Example (tp_degree = 32, strategy requested explicitly):
    #   num_attention_heads: 56 -> 64
    #   num_key_value_heads:  8 -> 64
    #   | K1 K1 | K1 K1 | K1 K1 | K1 K2 | ... | K8 K8   | Pad Pad | ... | Pad Pad |
    #   | Q1 Q2 | Q3 Q4 | Q5 Q6 | Q7 Q8 | ... | Q55 Q56 | Pad Pad | ... | Pad Pad |
    CONVERT_TO_MHA = "convert-to-mha"

    # REPLICATE_TO_TP_DEGREE replicates K/V heads until there is exactly one per rank (e.g. 8 K/V
    # heads with tp_degree = 32 become 32). It is more memory efficient but requires tp_degree to
    # be a multiple of the K/V head count. Q heads get interleaved padding so each group of Q heads
    # stays on the same rank as its K/V head.
    #
    # Example (tp_degree = 32):
    #   num_attention_heads: 56 -> 64
    #   num_key_value_heads:  8 -> 32
    #   | K1    | K1    | K1    | K1     | K2    | ...
    #   | Q1 Q2 | Q3 Q4 | Q5 Q6 | Q7 Pad | Q8 Q9 | ...
    REPLICATE_TO_TP_DEGREE = "replicate-to-tp-degree"


def determine_sharding_strategy(
    tp_degree: int, source_key_value_heads: int, desired_sharding_strategy: Optional[GQA] = None
) -> GQA:
    """Choose the GQA sharding strategy for ``tp_degree`` ranks.

    The requested strategy (``GQA.REPLICATE_TO_TP_DEGREE`` when none is given) is used as is,
    except that ``REPLICATE_TO_TP_DEGREE`` falls back to ``GQA.CONVERT_TO_MHA`` when
    ``tp_degree`` is not a multiple of ``source_key_value_heads``.

    Args:
        tp_degree: number of tensor-parallel ranks.
        source_key_value_heads: K/V head count of the source checkpoint.
        desired_sharding_strategy: strategy requested by the caller, if any.

    Returns:
        The strategy to apply.

    Raises:
        ZeroDivisionError: if ``source_key_value_heads`` is 0 and replication is being considered.
    """
    requested = desired_sharding_strategy or GQA.REPLICATE_TO_TP_DEGREE
    # The divisibility test is only evaluated for the replicate strategy (short-circuit `and`).
    if requested == GQA.REPLICATE_TO_TP_DEGREE and tp_degree % source_key_value_heads:
        return GQA.CONVERT_TO_MHA
    return requested


def get_shardable_head_counts(
    tp_degree: int, num_attention_heads: int, num_key_value_heads: int, sharding_strategy: GQA
) -> Tuple[int, int]:
    """Compute the Q and K/V head counts after padding/replication for ``tp_degree`` ranks.

    Query heads are always padded by ``get_number_of_extra_heads(num_attention_heads, tp_degree)``.
    K/V heads are handled as follows:

    * MHA (equal Q and K/V counts): K/V follows the padded Q count.
    * K/V count already a positive multiple of ``tp_degree``: unchanged.
    * Otherwise: ``tp_degree`` for ``GQA.REPLICATE_TO_TP_DEGREE``, the padded Q count for
      ``GQA.CONVERT_TO_MHA``.

    Args:
        tp_degree: number of tensor-parallel ranks.
        num_attention_heads: source query head count.
        num_key_value_heads: source key/value head count.
        sharding_strategy: strategy chosen by ``determine_sharding_strategy``.

    Returns:
        ``(num_attention_heads, num_key_value_heads)`` after padding/replication.

    Raises:
        AssertionError: for ``REPLICATE_TO_TP_DEGREE`` when ``tp_degree`` is not divisible by
            ``num_key_value_heads``.
    """
    padded_q_heads = num_attention_heads + get_number_of_extra_heads(num_attention_heads, tp_degree)

    if num_attention_heads == num_key_value_heads:
        # MHA: one K/V head per (padded) Q head.
        return padded_q_heads, padded_q_heads

    # GQA / MQA whose K/V heads already split evenly over the ranks: leave them alone.
    if num_key_value_heads >= tp_degree and num_key_value_heads % tp_degree == 0:
        return padded_q_heads, num_key_value_heads

    if sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
        assert tp_degree % num_key_value_heads == 0, (
            "GQA.REPLICATE_TO_TP_DEGREE requires tp_degree to be divisible by num_key_value_heads"
        )
        return padded_q_heads, tp_degree
    if sharding_strategy == GQA.CONVERT_TO_MHA:
        return padded_q_heads, padded_q_heads
    return padded_q_heads, num_key_value_heads


def maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: int, source_group_size: int):
    if tensor is None:
        return tensor

    # Why we convert FP8 tensor to bfloat16?
    # Torch does not support torch.cat, or torch.zeros (for large dimensions) for f8e4m3/f8e5m2
    # So we cast it to bfloat16, perform padding, and then recast back to f8e4m3/f8e5m2
    recast_dtype = None
    if tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        recast_dtype = tensor.dtype
        tensor = tensor.to(torch.bfloat16)

    shape = (
        tensor.shape[:pad_dim] + (source_heads, tensor.shape[pad_dim] // source_heads) + tensor.shape[pad_dim + 1 :]
    )
    tensor = tensor.view(shape)

    splits = torch.split(tensor, source_group_size, dim=pad_dim)

    pad_size = list(splits[0].size())
    pad_size[pad_dim] = (target_heads - source_heads) // (source_heads // source_group_size)
    pads = [torch.zeros(pad_size, dtype=tensor.dtype)] * len(splits)

    interleaved = [t for pair in zip(splits, pads) for t in pair]
    tensor = torch.cat(interleaved, dim=pad_dim)

    shape = tensor.shape[:pad_dim] + (tensor.shape[pad_dim] * tensor.shape[pad_dim + 1],) + tensor.shape[pad_dim + 2 :]

    if recast_dtype is not None:
        tensor = tensor.to(recast_dtype)

    return tensor.view(shape)


def maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int):
    """Zero-pad the end of ``pad_dim`` so ``tensor`` holds ``target_heads`` heads.

    ``tensor.shape[pad_dim]`` is assumed to hold ``source_heads`` equally sized heads. The padding
    is appended after the last existing head.

    Args:
        tensor: tensor to pad, or ``None``.
        source_heads: number of heads currently along ``pad_dim``.
        target_heads: number of heads wanted along ``pad_dim``.
        pad_dim: dimension to pad.

    Returns:
        The padded tensor, or ``None`` when ``tensor`` is ``None``.
    """
    if tensor is None:
        return None

    current_size = tensor.shape[pad_dim]
    head_size = current_size // source_heads
    missing = int(head_size * target_heads - current_size)

    # F.pad takes (before, after) pairs starting from the LAST dimension, so the pair for pad_dim
    # is the final one; only its "after" entry is non-zero.
    pad_spec = [0] * (2 * (tensor.dim() - pad_dim))
    pad_spec[-1] = missing
    return F.pad(tensor, tuple(pad_spec))


def replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0):
    """Repeat each of the ``source_heads`` heads ``repeats`` times along ``head_dim``.

    The copies of a head are kept adjacent (``h0 h0 .. h1 h1 ..``), via ``repeat_interleave``.

    Args:
        tensor: tensor whose ``head_dim`` axis holds ``source_heads * head_size`` entries, or ``None``.
        source_heads: number of heads currently along ``head_dim``.
        repeats: how many copies of each head to produce.
        head_dim: dimension holding the flattened heads.

    Returns:
        The replicated tensor, or ``None`` when ``tensor`` is ``None``.
    """
    if tensor is None:
        return None

    head_size = tensor.shape[head_dim] // source_heads
    split_shape = tensor.shape[:head_dim] + (source_heads, head_size) + tensor.shape[head_dim + 1 :]
    repeated = tensor.view(split_shape).repeat_interleave(repeats, dim=head_dim)

    before = repeated.shape[:head_dim]
    flat_heads = repeated.shape[head_dim] * repeated.shape[head_dim + 1]
    after = repeated.shape[head_dim + 2 :]
    return repeated.view(before + (flat_heads,) + after)


class BaseGroupQueryAttention(nn.Module):
    """Shared setup for the GQA projection modules.

    Resolves the tensor model parallel group (falling back to
    ``parallel_state.get_tensor_model_parallel_group()``), chooses a sharding strategy and computes
    the padded/replicated head counts used to size the projection layers. The checkpoint's
    original head counts are kept in ``_src_num_attention_heads`` / ``_src_num_key_value_heads``
    so that ``preshard_hook`` implementations can reshape checkpoint tensors.
    """

    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        tp_degree: int = 1,
        dtype: torch.dtype = torch.float32,
        bias: bool = False,
        desired_sharding_strategy: Optional[GQA] = None,
        tensor_model_parallel_group: Optional[ProcessGroup] = None,
    ):
        """
        Args:
            hidden_size: model hidden size.
            head_dim: size of a single attention head.
            num_attention_heads: query head count of the source checkpoint.
            num_key_value_heads: key/value head count of the source checkpoint.
            tp_degree: tensor-parallel degree; when left at 1 and a group is passed explicitly,
                it is taken from the group's size.
            dtype: parameter dtype forwarded to the parallel layers by subclasses.
            bias: whether the projections have a bias.
            desired_sharding_strategy: requested ``GQA`` strategy (may be overridden).
            tensor_model_parallel_group: process group to shard over.

        Raises:
            AssertionError: if an explicit group's size disagrees with a non-default ``tp_degree``.
        """
        super().__init__()

        if tensor_model_parallel_group is not None:
            self.tensor_model_parallel_group = tensor_model_parallel_group
        else:
            self.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()

        # NOTE(review): this tests the *argument*, not the resolved group, so when the group comes
        # from parallel_state, tp_degree is neither inferred from nor validated against it.
        if tensor_model_parallel_group:
            if tp_degree == 1:
                # update default value
                tp_degree = tensor_model_parallel_group.size()
            else:
                assert tp_degree == self.tensor_model_parallel_group.size(), (
                    f"TP Degree {tp_degree} and tensor model parallel group size {self.tensor_model_parallel_group.size()} does not match"
                )

        self.hidden_size = hidden_size
        self.tp_degree = tp_degree
        self.head_dim = head_dim
        self.dtype = dtype
        self.bias = bias
        # Head counts as stored in the source checkpoint (before padding / replication).
        self._src_num_attention_heads = num_attention_heads
        self._src_num_key_value_heads = num_key_value_heads

        self.sharding_strategy = determine_sharding_strategy(
            tp_degree,
            self._src_num_key_value_heads,
            desired_sharding_strategy=desired_sharding_strategy,
        )
        # Head counts actually used to size the layers (after padding / replication).
        self.num_attention_heads, self.num_key_value_heads = get_shardable_head_counts(
            tp_degree,
            self._src_num_attention_heads,
            self._src_num_key_value_heads,
            self.sharding_strategy,
        )

    def get_sharding_strategy(self) -> GQA:
        """Return the sharding strategy selected for this module."""
        return self.sharding_strategy

    def get_num_attention_heads(self) -> int:
        """Return the padded query head count."""
        return self.num_attention_heads

    def get_num_key_value_heads(self) -> int:
        """Return the padded/replicated key/value head count."""
        return self.num_key_value_heads

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Rewrite checkpoint tensors in ``model_state_dict``; implemented by subclasses."""
        raise NotImplementedError

    def replace_prefixes(self, old_prefix, new_prefix, model_state_dict):
        """Rename, in place, every key that contains ``old_prefix`` as whole dotted components.

        ``old_prefix`` only matches where it starts at the beginning of the key or right after a
        ``.``, and ends at the end of the key or right before a ``.``. A plain substring test would
        also rename unrelated keys; for example ``model.layers.0.self_attn.q_proj`` would match
        inside ``vision_model.layers.0.self_attn.q_proj.weight``.

        Args:
            old_prefix: dotted prefix to look for.
            new_prefix: replacement text.
            model_state_dict: state dict whose keys are renamed in place.
        """
        pattern = re.compile(r"(?<![^.])" + re.escape(old_prefix) + r"(?![^.])")

        # Collect first: the dict cannot be mutated while it is being iterated.
        renames = {}
        for key in model_state_dict:
            # A callable replacement keeps new_prefix literal (no backslash/group expansion).
            new_key = pattern.sub(lambda _match: new_prefix, key)
            if new_key != key:
                renames[key] = new_key

        for old_key, new_key in renames.items():
            model_state_dict[new_key] = model_state_dict.pop(old_key)


class GroupQueryAttention_QKV(BaseGroupQueryAttention):
    """Q, K and V input projections of a grouped-query attention block.

    The projections are either three separate layers (``q_proj``, ``k_proj``, ``v_proj``) or one
    fused ``Wqkv`` layer. ``ColumnParallelLinear`` layers are used when a tensor model parallel
    group is available, plain ``nn.Linear`` otherwise. Layer sizes use the padded/replicated head
    counts computed by ``BaseGroupQueryAttention``, and ``preshard_hook`` rewrites checkpoint
    tensors to match them.
    """

    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        tp_degree: int = 1,
        dtype: torch.dtype = torch.float32,
        bias: bool = False,
        desired_sharding_strategy: Optional[GQA] = None,
        gather_output: bool = True,
        fused_qkv: bool = False,
        clip_qkv: Optional[float] = None,
        sequence_parallel_enabled: bool = False,
        sequence_dimension: Optional[int] = None,
        tensor_model_parallel_group: Optional[ProcessGroup] = None,
        rms_norm_eps: Optional[float] = None,
        qkv_kernel_enabled: bool = False,
        logical_nc_config: int = 1,
    ):
        """
        Args:
            hidden_size, head_dim, num_attention_heads, num_key_value_heads, tp_degree, dtype,
            bias, desired_sharding_strategy, tensor_model_parallel_group: see
                ``BaseGroupQueryAttention``.
            gather_output: forwarded to ``ColumnParallelLinear``.
            fused_qkv: use one fused ``Wqkv`` projection instead of three separate ones.
            clip_qkv: when set, Q/K/V outputs are clamped to ``[-clip_qkv, clip_qkv]``.
            sequence_parallel_enabled: gather ``hidden_states`` along ``sequence_dimension`` before
                projecting (only when a TP group is present).
            sequence_dimension: dimension used for that gather.
            rms_norm_eps: epsilon passed to the NKI QKV kernel.
            qkv_kernel_enabled: run the fused projection through the NKI kernel (requires
                ``fused_qkv``; checked in ``forward``).
            logical_nc_config: value passed to ``nc(...)`` to build the kernel launch grid.

        Raises:
            ValueError: if both ``fused_qkv`` and ``gather_output`` are set.
        """
        super().__init__(
            hidden_size=hidden_size,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            tp_degree=tp_degree,
            dtype=dtype,
            bias=bias,
            desired_sharding_strategy=desired_sharding_strategy,
            tensor_model_parallel_group=tensor_model_parallel_group,
        )
        if fused_qkv and gather_output:
            raise ValueError(
                "Gathering states followed by fused qkv is not allowed as it has a different weight sharding scheme."
            )

        self.gather_output = gather_output
        self.fused_qkv = fused_qkv
        self.clip_qkv = clip_qkv

        self.sequence_parallel_enabled = sequence_parallel_enabled
        self.sequence_dimension = sequence_dimension
        self.rms_norm_eps = rms_norm_eps
        self.qkv_kernel_enabled = qkv_kernel_enabled
        self.logical_nc_config = logical_nc_config

        if self.tensor_model_parallel_group is not None:
            if self.fused_qkv:
                self.Wqkv = ColumnParallelLinear(
                    self.hidden_size,
                    (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
                if self.qkv_kernel_enabled:
                    # we need to transpose the weights on the CPU side to avoid
                    # needing to transpose on the device when using QKV kernel
                    self.Wqkv.weight = transpose_parallel_linear_layer(self.Wqkv.weight)

                # Set heads info as weight parameter attributes to be used in weights sharding
                setattr(self.Wqkv.weight, "fused_qkv", True)
                setattr(self.Wqkv.weight, "num_attention_heads", self.num_attention_heads)
                setattr(self.Wqkv.weight, "num_key_value_heads", self.num_key_value_heads)
                setattr(self.Wqkv.weight, "head_dim", self.head_dim)

            else:
                self.q_proj = ColumnParallelLinear(
                    self.hidden_size,
                    self.num_attention_heads * self.head_dim,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    sequence_parallel_enabled=False,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
                self.k_proj = ColumnParallelLinear(
                    self.hidden_size,
                    self.num_key_value_heads * self.head_dim,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    sequence_parallel_enabled=False,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
                self.v_proj = ColumnParallelLinear(
                    self.hidden_size,
                    self.num_key_value_heads * self.head_dim,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    sequence_parallel_enabled=False,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
        else:
            if self.fused_qkv:
                self.Wqkv = nn.Linear(
                    self.hidden_size,
                    (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim,
                    bias=self.bias,
                )
            else:
                self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.bias)
                self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias)
                self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias)

    def forward(self, hidden_states: torch.Tensor, rmsnorm=None):
        """Project ``hidden_states`` to ``(Q, K, V)``.

        Args:
            hidden_states: input activations; the kernel path unpacks them as
                ``(batch, seqlen, hidden)``.
            rmsnorm: optional norm module whose ``weight`` is handed to the NKI kernel; ignored on
                the native path.

        Returns:
            Tuple ``(Q, K, V)``.
        """
        if self.sequence_parallel_enabled and self.tensor_model_parallel_group is not None:
            hidden_states = gather_from_sequence_parallel_region(
                hidden_states,
                self.sequence_dimension,
                process_group=self.tensor_model_parallel_group,
            )

        if self.qkv_kernel_enabled:
            assert self.fused_qkv, "QKV kernel only supported when fused_qkv is TRUE"
            # With sequence parallelism the norm is applied outside the kernel.
            fused_rmsnorm = not self.sequence_parallel_enabled
            return self._kernel_qkv_forward(hidden_states, fused_rmsnorm, rmsnorm)
        else:
            return self._native_qkv_forward(hidden_states)

    def _native_qkv_forward(self, hidden_states: torch.Tensor):
        """Compute Q/K/V with the regular (non-kernel) projection layers."""
        if self.fused_qkv:
            logger.debug("QKV: native compiler")
            QKV = self.Wqkv(hidden_states)
            return self._split_fused_qkv(QKV)
        else:
            Q = self.q_proj(hidden_states)
            K = self.k_proj(hidden_states)
            V = self.v_proj(hidden_states)
            if self.clip_qkv is not None:
                Q = Q.clamp(min=-self.clip_qkv, max=self.clip_qkv)
                K = K.clamp(min=-self.clip_qkv, max=self.clip_qkv)
                V = V.clamp(min=-self.clip_qkv, max=self.clip_qkv)
            return Q, K, V

    def _split_fused_qkv(self, QKV):
        """Split this rank's fused QKV activation along dim 2 into Q, K and V (clamped if configured)."""
        logger.debug(f"Fused QKV tensor has shape {QKV.shape}")
        if self.clip_qkv is not None:
            QKV = QKV.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        # shape of QKV is [batch, seqlen, fused_qkv_size]
        # we split the fused QKV (dim=2) into Q, K, V
        # for example:
        #   for 405B, TP=128, num_att_heads=128
        #   LNC=2/TP=64 will split QKV from [batch, seqlen, 512] into:
        #   Q [batch, seqlen, 256]
        #   K [batch, seqlen, 128]
        #   V [batch, seqlen, 128]
        # torch.split has accuracy issue and leads to more reshapes in hlo.
        # Using torch.tensor_split here. NAPP-3145
        # Per-rank sizes: the head counts are global, so divide by tp_degree.
        q_end_index = self.num_attention_heads * self.head_dim // self.tp_degree
        k_end_index = q_end_index + self.num_key_value_heads * self.head_dim // self.tp_degree
        Q, K, V = torch.tensor_split(
            QKV,
            (
                q_end_index,
                k_end_index,
                # rest of the QKV will go to V output
            ),
            dim=2,
        )
        logger.debug(f"QKV shape before tensor_split: {QKV.shape}")
        logger.debug(f"Q shape after tensor_split: {Q.shape}")
        logger.debug(f"K shape after tensor_split: {K.shape}")
        logger.debug(f"V shape after tensor_split: {V.shape}")
        return Q, K, V

    def _kernel_qkv_forward(self, hidden_states, fused_rmsnorm, rmsnorm):
        """Compute the fused QKV projection with the NKI kernel, then split it into Q, K, V.

        ``self.Wqkv.weight`` is expected in the transposed ``(hidden, fused_qkv_size)`` layout
        produced in ``__init__`` when ``qkv_kernel_enabled`` is set.
        """
        logger.debug(f"QKV kernel: fused_rmsnorm={fused_rmsnorm} logical_nc_config={self.logical_nc_config}")
        bs, seqlen, h = hidden_states.shape

        h2, fused_qkv_size = self.Wqkv.weight.shape
        logger.debug(f"fused QKV projection weight - shape: {self.Wqkv.weight.shape}, dtype: {self.Wqkv.weight.dtype}")

        # shape checks
        assert (
            fused_qkv_size
            == (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim // self.tp_degree
        )
        assert h == h2

        # Output buffer written by the kernel.
        QKV = torch.zeros(
            bs,
            seqlen,
            fused_qkv_size,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        grid = (nc(self.logical_nc_config),)

        # the QKV kernel will automatically switch to the TKG QKV if seqlen==1
        _traced_qkv_kernel[grid](
            hidden_states,
            self.Wqkv.weight,
            # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden]
            # should be fine to pass this is as a dummy even if not using fused rmsnorm
            rmsnorm.weight.unsqueeze(0) if rmsnorm else torch.ones((1, h), device=hidden_states.device),
            QKV,
            eps=self.rms_norm_eps,
            kernel_name="QKV",
            # Run RMSNorm inside the kernel if NOT using SP norm
            fused_rmsnorm=(fused_rmsnorm and rmsnorm is not None),
        )
        return self._split_fused_qkv(QKV)

    def get_weight(
        self, prefix: str, layer: torch.nn.Module, layer_name, model_state_dict: dict
    ) -> torch.Tensor:
        """Read ``layer_name``'s weight, delegating to ``layer.get_weight_from_state_dict`` if present."""
        if hasattr(layer, "get_weight_from_state_dict"):
            return layer.get_weight_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict)
        return model_state_dict[f"{prefix}.{layer_name}.weight"]

    def get_bias(
        self, prefix: str, layer: torch.nn.Module, layer_name: str, model_state_dict: dict
    ) -> Optional[torch.Tensor]:
        """Read ``layer_name``'s bias (``None`` if absent), delegating to the layer if it can."""
        if hasattr(layer, "get_bias_from_state_dict"):
            return layer.get_bias_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict)
        return model_state_dict.get(f"{prefix}.{layer_name}.bias")

    def set_weight(
        self,
        tensor: torch.Tensor,
        prefix: str,
        layer: torch.nn.Module,
        layer_name,
        model_state_dict: dict,
    ) -> None:
        """Store ``tensor`` under ``{prefix}.{layer_name}.weight`` (never delegates to the layer)."""
        # TODO: set weight to state dict support is pending.
        model_state_dict[f"{prefix}.{layer_name}.weight"] = tensor

    def set_bias(
        self,
        tensor: torch.Tensor,
        prefix: str,
        layer: torch.nn.Module,
        layer_name: str,
        model_state_dict: dict,
    ) -> None:
        """Store ``tensor`` as ``layer_name``'s bias, delegating to ``layer.set_bias_to_state_dict`` if present."""
        if hasattr(layer, "set_bias_to_state_dict"):
            layer.set_bias_to_state_dict(prefix=f"{prefix}.{layer_name}.", tensor=tensor, state_dict=model_state_dict)
        else:
            model_state_dict[f"{prefix}.{layer_name}.bias"] = tensor

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Rewrite checkpoint Q/K/V tensors in ``model_state_dict`` to the sharded head layout.

        Steps: move the source keys (``<hf_prefix>.{q,k,v}_proj`` or ``<hf_prefix>.Wqkv``) under
        this module's prefix; replicate K/V heads if the K/V head count changed; pad Q heads
        (interleaved for ``REPLICATE_TO_TP_DEGREE``, at the tail for ``CONVERT_TO_MHA``, where K/V
        are tail-padded too); write the results back.

        Args:
            model_state_dict: state dict modified in place.
            prefix: this module's prefix. Its last ``.``-separated component is dropped
                (presumably an empty trailing component; TODO confirm against the caller), and the
                one before that is replaced by the checkpoint's projection names.

        Returns:
            Always ``True``.
        """
        prefix_parts = prefix.split(".")
        prefix = ".".join(prefix_parts[:-1])
        hf_prefix = ".".join(prefix_parts[:-2])
        projection_names = ("q_proj", "k_proj", "v_proj")

        # --- Load the checkpoint tensors (source head counts) ---
        if self.fused_qkv:
            self.replace_prefixes(
                old_prefix=f"{hf_prefix}.Wqkv",
                new_prefix=f"{prefix}.Wqkv",
                model_state_dict=model_state_dict,
            )
            split_sizes = [
                self._src_num_attention_heads * self.head_dim,
                self._src_num_key_value_heads * self.head_dim,
                self._src_num_key_value_heads * self.head_dim,
            ]
            qkv_weight = self.get_weight(
                prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
            )
            q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split(split_sizes, dim=0)
            qkv_bias = self.get_bias(
                prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
            )
            if qkv_bias is not None:
                q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split(split_sizes, dim=0)
            else:
                q_proj_bias, k_proj_bias, v_proj_bias = None, None, None
        else:
            for layer_name in projection_names:
                self.replace_prefixes(
                    old_prefix=f"{hf_prefix}.{layer_name}",
                    new_prefix=f"{prefix}.{layer_name}",
                    model_state_dict=model_state_dict,
                )
            q_proj_weight, k_proj_weight, v_proj_weight = (
                self.get_weight(
                    prefix=prefix,
                    layer=getattr(self, layer_name),
                    layer_name=layer_name,
                    model_state_dict=model_state_dict,
                )
                for layer_name in projection_names
            )
            q_proj_bias, k_proj_bias, v_proj_bias = (
                self.get_bias(
                    prefix=prefix,
                    layer=getattr(self, layer_name),
                    layer_name=layer_name,
                    model_state_dict=model_state_dict,
                )
                for layer_name in projection_names
            )

        # --- Replicate K/V heads ---
        # Number of heads currently held by the K/V tensors; updated if they get replicated.
        kv_heads = self._src_num_key_value_heads
        if self.num_key_value_heads != self._src_num_key_value_heads:
            if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
                repeats = self.tp_degree // self._src_num_key_value_heads
            elif self.sharding_strategy == GQA.CONVERT_TO_MHA:
                repeats = self._src_num_attention_heads // self._src_num_key_value_heads
            k_proj_weight, k_proj_bias, v_proj_weight, v_proj_bias = (
                replicate_kv(t, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0)
                for t in (k_proj_weight, k_proj_bias, v_proj_weight, v_proj_bias)
            )
            kv_heads = self._src_num_key_value_heads * repeats

        # --- Pad heads ---
        if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
            q_group_size = self._src_num_attention_heads // self._src_num_key_value_heads
            q_proj_weight, q_proj_bias = (
                maybe_pad_interleaved(
                    t,
                    pad_dim=0,
                    source_heads=self._src_num_attention_heads,
                    target_heads=self.num_attention_heads,
                    source_group_size=q_group_size,
                )
                for t in (q_proj_weight, q_proj_bias)
            )

        if self.sharding_strategy == GQA.CONVERT_TO_MHA:
            q_proj_weight, q_proj_bias = (
                maybe_pad_tail(
                    t,
                    source_heads=self._src_num_attention_heads,
                    target_heads=self.num_attention_heads,
                    pad_dim=0,
                )
                for t in (q_proj_weight, q_proj_bias)
            )
            # K/V may have been replicated above, so pad from their *current* head count. Using the
            # checkpoint's K/V count would overestimate the per-head size by `repeats` and produce
            # a tensor `repeats` times too long.
            k_proj_weight, k_proj_bias, v_proj_weight, v_proj_bias = (
                maybe_pad_tail(t, source_heads=kv_heads, target_heads=self.num_key_value_heads, pad_dim=0)
                for t in (k_proj_weight, k_proj_bias, v_proj_weight, v_proj_bias)
            )

        # --- Write the resharded tensors back ---
        if self.fused_qkv:
            qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
            self.set_weight(
                tensor=qkv_weight,
                prefix=prefix,
                layer=self.Wqkv,
                layer_name="Wqkv",
                model_state_dict=model_state_dict,
            )
            if self.bias:
                qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0)
                self.set_bias(
                    tensor=qkv_bias,
                    prefix=prefix,
                    layer=self.Wqkv,
                    layer_name="Wqkv",
                    model_state_dict=model_state_dict,
                )
        else:
            resharded = (
                ("q_proj", q_proj_weight, q_proj_bias),
                ("k_proj", k_proj_weight, k_proj_bias),
                ("v_proj", v_proj_weight, v_proj_bias),
            )
            # All weights first, then all biases (same key insertion order as before).
            for layer_name, weight, _ in resharded:
                self.set_weight(
                    tensor=weight,
                    prefix=prefix,
                    layer=getattr(self, layer_name),
                    layer_name=layer_name,
                    model_state_dict=model_state_dict,
                )
            if self.bias:
                for layer_name, _, bias_tensor in resharded:
                    self.set_bias(
                        tensor=bias_tensor,
                        prefix=prefix,
                        layer=getattr(self, layer_name),
                        layer_name=layer_name,
                        model_state_dict=model_state_dict,
                    )

        return True


class GroupQueryAttention_O(BaseGroupQueryAttention):
    """Output projection of a grouped-query attention block.

    Uses ``RowParallelLinear`` when a tensor model parallel group is available and ``nn.Linear``
    otherwise. Its input size is the padded query head count times ``head_dim``, and
    ``preshard_hook`` pads the checkpoint weight's input columns to match.
    """

    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        tp_degree: int = 1,
        dtype: torch.dtype = torch.float32,
        bias: bool = False,
        desired_sharding_strategy: Optional[GQA] = None,
        input_is_parallel: bool = False,
        layer_name: str = "o_proj",
        sequence_parallel_enabled: bool = False,
        sequence_dimension: Optional[int] = None,
        tensor_model_parallel_group: Optional[ProcessGroup] = None,
        rpl_reduce_dtype: torch.dtype = None,
    ):
        """
        Args:
            hidden_size, head_dim, num_attention_heads, num_key_value_heads, tp_degree, dtype,
            bias, desired_sharding_strategy, tensor_model_parallel_group: see
                ``BaseGroupQueryAttention``.
            input_is_parallel: forwarded to ``RowParallelLinear``.
            layer_name: name of this projection in the source state dict (e.g. ``"out_proj"`` for
                the CLIP vision model); it is renamed to ``o_proj`` by ``preshard_hook``.
            sequence_parallel_enabled, sequence_dimension: forwarded to ``RowParallelLinear``.
            rpl_reduce_dtype: forwarded to ``RowParallelLinear`` as ``reduce_dtype``.
        """
        super().__init__(
            hidden_size=hidden_size,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            tp_degree=tp_degree,
            dtype=dtype,
            bias=bias,
            desired_sharding_strategy=desired_sharding_strategy,
            tensor_model_parallel_group=tensor_model_parallel_group,
        )

        self.input_is_parallel = input_is_parallel

        projected_features = self.num_attention_heads * self.head_dim
        if self.tensor_model_parallel_group is None:
            self.o_proj = nn.Linear(projected_features, self.hidden_size, bias=self.bias)
        else:
            self.o_proj = RowParallelLinear(
                projected_features,
                self.hidden_size,
                bias=self.bias,
                input_is_parallel=self.input_is_parallel,
                dtype=self.dtype,
                sequence_parallel_enabled=sequence_parallel_enabled,
                sequence_dimension=sequence_dimension,
                tensor_model_parallel_group=self.tensor_model_parallel_group,
                reduce_dtype=rpl_reduce_dtype,
            )

        # Name of this projection in the source state dict (e.g. "out_proj" for CLIP vision).
        self.layer_name = layer_name

    def forward(self, attention_output: torch.Tensor):
        """Apply the output projection to ``attention_output``."""
        return self.o_proj(attention_output)

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Move and pad the checkpoint output-projection weight in ``model_state_dict``.

        The source ``<hf_prefix>.<layer_name>`` keys are renamed to ``<module prefix>.o_proj``. The
        weight's input dimension (dim 1) is then padded to the padded query head count:
        interleaved for ``REPLICATE_TO_TP_DEGREE``, at the tail for ``CONVERT_TO_MHA``. Only the
        weight is reshaped.

        Returns:
            Always ``True``.
        """
        parts = prefix.split(".")
        module_prefix = ".".join(parts[:-1])
        hf_prefix = ".".join(parts[:-2])
        weight_key = f"{module_prefix}.o_proj.weight"

        self.replace_prefixes(
            old_prefix=f"{hf_prefix}.{self.layer_name}",
            new_prefix=f"{module_prefix}.o_proj",
            model_state_dict=model_state_dict,
        )
        weight = model_state_dict[weight_key]

        strategy = self.sharding_strategy
        if strategy == GQA.REPLICATE_TO_TP_DEGREE:
            weight = maybe_pad_interleaved(
                weight,
                pad_dim=1,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
            )
        elif strategy == GQA.CONVERT_TO_MHA:
            weight = maybe_pad_tail(
                weight,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                pad_dim=1,
            )

        model_state_dict[weight_key] = weight
        return True
