# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Any, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

import vllm.distributed.parallel_state as parallel_state
from vllm.attention.backends.abstract import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import (LlamaAttention,
                                              LlamaDecoderLayer,
                                              LlamaMLP)
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

# Detect whether the FlashInfer attention backend can be imported.
try:
    from vllm.v1.attention.backends.flashinfer import FlashInferMetadata
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False
    FlashInferMetadata = None

import arctic_inference.vllm.model_runner as model_runner
from arctic_inference.common.swiftkv.configs import LlamaSwiftKVConfig

logger = init_logger(__name__)


def get_attn_metadata_for_swiftkv():
    """Return the attention metadata shared by every entry of the context.

    Reads ``attn_metadata`` from the current forward context. When it is
    ``None`` the function returns ``None``. Otherwise the first value of the
    mapping is returned, after asserting that every value in the mapping is
    that very same object.
    """
    metadata_by_layer = get_forward_context().attn_metadata
    if metadata_by_layer is None:
        return None
    shared_metadata = next(iter(metadata_by_layer.values()))
    for candidate in metadata_by_layer.values():
        assert candidate is shared_metadata, \
            "All attention metadata should be the same for LlamaSwiftKV."
    return shared_metadata


class LlamaSwiftKVAttention(LlamaAttention):
    """Llama attention with extra SwiftKV query and key/value projections.

    On top of everything ``LlamaAttention`` builds, this adds
    ``q_proj_swiftkv`` (query projection) and ``kv_proj_swiftkv`` (a
    ``QKVParallelLinear`` created with zero query heads, i.e. key/value
    only). ``forward`` receives keys and values as arguments instead of
    projecting them from ``hidden_states``.
    """

    def __init__(
        self,
        config: LlamaSwiftKVConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the base attention, then the two SwiftKV projections."""
        # Every argument is forwarded unchanged to the parent constructor.
        parent_kwargs = dict(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type,
        )
        super().__init__(**parent_kwargs)

        # Options common to both SwiftKV projections.
        swiftkv_linear_opts = {"bias": bias, "quant_config": quant_config}

        self.q_proj_swiftkv = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=self.total_num_heads * self.head_dim,
            gather_output=False,
            prefix=f"{prefix}.q_proj_swiftkv",
            **swiftkv_linear_opts,
        )

        # total_num_heads=0: this fused projection has no query part.
        self.kv_proj_swiftkv = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_kv_heads,
            prefix=f"{prefix}.kv_proj_swiftkv",
            **swiftkv_linear_opts,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Attend with a freshly projected query over the supplied ``k``/``v``.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input from which the query is projected.
            k: Key states handed to the attention layer as-is.
            v: Value states handed to the attention layer as-is.

        Returns:
            The attention output after the output projection.
        """
        query, _ = self.q_proj_swiftkv(hidden_states)
        # The rotary call also takes a key argument; an uninitialized tensor
        # shaped like ``k`` is supplied and its rotated result is discarded.
        key_placeholder = torch.empty_like(k)
        query, _ = self.rotary_emb(positions, query, key_placeholder)

        # The attention call works the same for both FlashAttention and
        # FlashInfer as they both use the same interface: self.attn(q, k, v)
        attended = self.attn(query, k, v)
        projected, _ = self.o_proj(attended)
        return projected


class LlamaSwiftKVDecoderLayer(nn.Module):
    """Decoder layer whose attention consumes externally supplied K/V states.

    The layer is made of a ``LlamaSwiftKVAttention``, a ``LlamaMLP`` and two
    ``RMSNorm`` modules (before attention and before the MLP).
    """

    def __init__(
        self,
        config: LlamaSwiftKVConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the attention, MLP and norm sub-modules from ``config``."""
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_positions = getattr(
            config, "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max_positions:
            # Note: this writes into the dict object held by ``config``.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = (getattr(config, "attention_bias", False)
                          or getattr(config, "bias", False))
        # Fall back to one KV head per attention head when unspecified.
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)

        self.self_attn = LlamaSwiftKVAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run norm -> attention -> norm -> MLP for one layer.

        Args:
            positions: Positions forwarded to the attention module.
            hidden_states: Layer input.
            k_states: Key states forwarded to the attention module.
            v_states: Value states forwarded to the attention module.
            residual: Residual carried from the previous layer, or ``None``,
                in which case the incoming ``hidden_states`` becomes the
                residual.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            k=k_states,
            v=v_states,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


@support_torch_compile
class LlamaSwiftKVPrefillRunner(nn.Module):
    """Runs the first ``num_key_value_layers`` layers and the SwiftKV K/V
    projections of all remaining layers."""

    def __init__(self, *, vllm_config: VllmConfig, model: "LlamaSwiftKVModel",
                 prefix: str = ""):
        """Keep a reference to ``model`` and to the HF config."""
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # Box it to avoid recursive registration
        self._model = [model]

    @property
    def model(self) -> "LlamaSwiftKVModel":
        """The model this runner operates on."""
        return self._model[0]

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Embed the inputs, run the leading layers and project K/V.

        Returns:
            ``(hidden_states, residual, positions, k_states, v_states)``,
            where ``k_states`` / ``v_states`` are the per-layer keys / values
            of the remaining layers concatenated along the last dimension.
        """
        num_kv_layers = self.config.num_key_value_layers
        hidden_states = self.model.get_input_embeddings(input_ids)
        residual = None
        for layer in self.model.layers[:num_kv_layers]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if (parallel_state._SP.world_size > 1
                and not model_runner.is_shift_parallel_mode()):
            # All-gather across ulysses sequence parallel ranks
            sp_group = parallel_state._SP
            hidden_states = sp_group.all_gather(hidden_states, dim=0)
            residual = sp_group.all_gather(residual, dim=0)
            positions = sp_group.all_gather(positions, dim=0)

        # Temporarily switch the module-level SP/TP state; the previous
        # values are put back once the projections are done.
        saved_mode = model_runner.SP_TP_MODE
        saved_tp_group = parallel_state.get_tp_group()
        model_runner.SP_TP_MODE = True
        parallel_state._TP = parallel_state._SP_TP

        # KV projection of all the remaining layers
        swiftkv_hidden_states = (
            self.model.norm_swiftkv(hidden_states + residual))

        rotary_emb = self.model.layers[0].self_attn.rotary_emb
        # Just temporary buffer
        scratch_q = torch.empty_like(hidden_states)
        keys = []
        values = []
        for layer in self.model.layers[num_kv_layers:]:
            fused_kv, _ = layer.self_attn.kv_proj_swiftkv(
                swiftkv_hidden_states)
            key, value = fused_kv.chunk(2, dim=-1)
            _, key = rotary_emb(positions, scratch_q, key)
            keys.append(key)
            values.append(value)
        k_states = torch.cat(keys, dim=-1)
        v_states = torch.cat(values, dim=-1)

        model_runner.SP_TP_MODE = saved_mode
        parallel_state._TP = saved_tp_group

        return hidden_states, residual, positions, k_states, v_states


@support_torch_compile
class LlamaSwiftKVDecodeRunner(nn.Module):
    """Runs the layers after the first ``num_key_value_layers`` and the
    final norm, using K/V states supplied by the caller."""

    def __init__(self, *, vllm_config: VllmConfig, model: "LlamaSwiftKVModel",
                 prefix: str = ""):
        """Keep a reference to ``model`` and to the HF config."""
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # Box it to avoid recursive registration
        self._model = [model]

    @property
    def model(self) -> "LlamaSwiftKVModel":
        """The model this runner operates on."""
        return self._model[0]

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the remaining layers and return the normalized hidden states.

        ``k_states`` and ``v_states`` are chunked along their last dimension
        into one piece per remaining layer.
        """
        # This is a hint for the compiler that v_states and k_states have
        # the same shape so that a single symbolic shape is inferred.
        torch._check(v_states.shape[0] == k_states.shape[0])

        first_swiftkv_layer = self.config.num_key_value_layers
        num_swiftkv_layers = (self.config.num_hidden_layers -
                              first_swiftkv_layer)
        keys_per_layer = torch.chunk(k_states, num_swiftkv_layers, dim=-1)
        values_per_layer = torch.chunk(v_states, num_swiftkv_layers, dim=-1)

        swiftkv_layers = self.model.layers[first_swiftkv_layer:]
        for layer_pos, layer in enumerate(swiftkv_layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                keys_per_layer[layer_pos],
                values_per_layer[layer_pos],
                residual,
            )
        normed, _ = self.model.norm(hidden_states, residual)
        return normed


class LlamaSwiftKVModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, both kinds of layers, the norms and the runners.

        The first ``num_key_value_layers`` layers are ``LlamaDecoderLayer``;
        the rest are ``LlamaSwiftKVDecoderLayer`` and are created (together
        with the two norms) while shift-parallel mode is enabled.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config

        self.vllm_config = vllm_config
        self.quant_config = vllm_config.quant_config
        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id

        # Extra vocabulary entries reserved for LoRA adapters, if any.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            quant_config=self.quant_config,
        )

        num_kv_layers = hf_config.num_key_value_layers
        layer_kwargs = dict(config=hf_config,
                            cache_config=vllm_config.cache_config,
                            quant_config=vllm_config.quant_config)

        standard_layers = [
            LlamaDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}",
                              **layer_kwargs)
            for layer_idx in range(num_kv_layers)
        ]
        self.layers = torch.nn.ModuleList(standard_layers)

        with model_runner.set_shift_parallel_mode(True):
            swiftkv_layers = [
                LlamaSwiftKVDecoderLayer(
                    prefix=f"{prefix}.layers.{layer_idx}", **layer_kwargs)
                for layer_idx in range(num_kv_layers,
                                       hf_config.num_hidden_layers)
            ]
            self.layers.extend(swiftkv_layers)
            self.norm_swiftkv = RMSNorm(hf_config.hidden_size,
                                        eps=hf_config.rms_norm_eps)
            self.norm = RMSNorm(hf_config.hidden_size,
                                eps=hf_config.rms_norm_eps)

        # Tag the SwiftKV layers' parameters; load_weights reads this flag.
        for param in self.layers[num_kv_layers:].parameters():
            param.shift_parallel_mode = True

        self._init_prefill_runner(vllm_config)
        self._init_decode_runner(vllm_config)

        from arctic_inference.py_custom_ops import try_load_torch_library
        self.use_custom_ops = bool(try_load_torch_library())

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def _init_prefill_runner(self, vllm_config: VllmConfig):
        """Create ``self.prefill_runner`` with its own compilation config.

        ``vllm_config.compilation_config`` is replaced by a shallow copy whose
        ``inductor_compile_config`` is itself copied, before the runner is
        constructed from ``vllm_config``.
        """
        vllm_config.compilation_config = copy.copy(
            vllm_config.compilation_config)
        compilation_config = vllm_config.compilation_config
        compilation_config.inductor_compile_config = (
            compilation_config.inductor_compile_config.copy())

        runner = LlamaSwiftKVPrefillRunner(vllm_config=vllm_config,
                                           model=self)
        self.prefill_runner = runner

    def _init_decode_runner(self, vllm_config: VllmConfig):
        vllm_config.compilation_config = copy.copy(
            vllm_config.compilation_config)
        vllm_config.compilation_config.inductor_compile_config = (
            vllm_config.compilation_config.inductor_compile_config.copy())
        self.decode_runner = LlamaSwiftKVDecodeRunner(
            vllm_config=vllm_config, model=self)

        config = vllm_config.model_config.hf_config
        if vllm_config.compilation_config.cudagraph_capture_sizes:
            self.cuda_graph_max_batch_size = max(
                vllm_config.compilation_config.cudagraph_capture_sizes)
            num_heads = self.layers[-1].self_attn.attn.num_kv_heads
            head_size = self.layers[-1].self_attn.attn.head_size
            num_kv = config.num_hidden_layers - config.num_key_value_layers
            kv_size = num_kv * num_heads * head_size
            self.decode_runner.inputs = {
                "hidden_states": torch.empty(self.cuda_graph_max_batch_size,
                                             config.hidden_size, device="cuda"),
                "residual": torch.empty(self.cuda_graph_max_batch_size,
                                        config.hidden_size, device="cuda"),
                "positions": torch.empty(self.cuda_graph_max_batch_size,
                                         dtype=torch.long, device="cuda"),
                "k_states": torch.empty(self.cuda_graph_max_batch_size,
                                        kv_size, device="cuda"),
                "v_states": torch.empty(self.cuda_graph_max_batch_size,
                                        kv_size, device="cuda"),
            }
        else:
            self.cuda_graph_max_batch_size = 0

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def _fix_flash_attention_metadata(self, attn_metadata, logits_indices, num_surviving_tokens):
        # FlashAttention path
        attn_metadata.num_actual_tokens = num_surviving_tokens
        attn_metadata.query_start_loc = torch.searchsorted(
            logits_indices, attn_metadata.query_start_loc, out_int32=True)
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[
            logits_indices]
        
        # TODO: Make cascade attention work with SwiftKV
        attn_metadata.use_cascade = False
        attn_metadata.cu_prefix_query_lens = None
        attn_metadata.prefix_kv_lens = None
        attn_metadata.suffix_kv_lens = None
        attn_metadata.prefix_scheduler_metadata = None

    def _fix_flashinfer_metadata(self, attn_metadata, logits_indices, num_surviving_tokens):
        # FlashInfer path
        # 1. get survived requests and get their token counts.
        original_num_tokens = attn_metadata.num_actual_tokens
        token_to_req_id = torch.searchsorted(
            attn_metadata.qo_indptr,
            torch.arange(original_num_tokens,
                         device=logits_indices.device),
            right=True) - 1
        surviving_tokens_flat_req_ids = token_to_req_id[logits_indices]
        surviving_req_ids, surviving_tokens_per_req = torch.unique(surviving_tokens_flat_req_ids, return_counts=True)
        new_num_reqs = surviving_req_ids.numel()

        # 2. classify surviving requests as decode vs prefill
        # decode: exactly 1 token, prefill: > 1 token
        decode_mask = surviving_tokens_per_req == 1
        prefill_mask = surviving_tokens_per_req > 1
        
        decode_req_ids = surviving_req_ids[decode_mask]
        prefill_req_ids = surviving_req_ids[prefill_mask]
        
        new_num_decodes = decode_req_ids.numel()
        new_num_prefills = prefill_req_ids.numel()
        new_num_decode_tokens = decode_mask.sum().item()
        new_num_prefill_tokens = prefill_mask.sum().item()

        # 3. build qo_indptr for surviving requests (decode first, then prefill)
        # Reorder surviving requests: decode first, then prefill
        reordered_req_ids = torch.cat([decode_req_ids, prefill_req_ids])
        reordered_tokens_per_req = torch.cat([
            surviving_tokens_per_req[decode_mask],
            surviving_tokens_per_req[prefill_mask]
        ])
        attn_metadata.qo_indptr = torch.nn.functional.pad(torch.cumsum(reordered_tokens_per_req, dim=0), (1, 0))

        # 4. build paged KV cache metadata for surviving requests
        original_num_pages_per_req = attn_metadata.paged_kv_indptr.diff()
        reordered_num_pages_per_req = original_num_pages_per_req[reordered_req_ids]
        page_indices_start = attn_metadata.paged_kv_indptr[reordered_req_ids]
        page_indices_end = attn_metadata.paged_kv_indptr[reordered_req_ids + 1]

        if new_num_reqs > 0:
            # create page indices for each surviving request
            page_indices_list = []
            for i in range(new_num_reqs):
                start_idx = page_indices_start[i]
                end_idx = page_indices_end[i]
                page_indices_list.append(
                    attn_metadata.paged_kv_indices[start_idx:end_idx])
            attn_metadata.paged_kv_indices = torch.cat(page_indices_list)
        else:
            # no requests survive SwiftKV selection
            attn_metadata.paged_kv_indices = torch.empty(
                0,
                dtype=attn_metadata.paged_kv_indices.dtype,
                device=attn_metadata.paged_kv_indices.device)

        # build paged_kv_indptr for surviving requests
        attn_metadata.paged_kv_indptr = torch.nn.functional.pad(torch.cumsum(reordered_num_pages_per_req, dim=0), (1, 0)).int()
        # update last page lengths for surviving requests
        attn_metadata.paged_kv_last_page_len = attn_metadata.paged_kv_last_page_len[reordered_req_ids]

        # 5. create reordered logits_indices (decode tokens first, then prefill tokens)
        # Map original req_ids to new positions
        old_to_new_req_pos = torch.full((surviving_req_ids.max() + 1,), -1, 
                                       dtype=torch.long, device=logits_indices.device)
        old_to_new_req_pos[reordered_req_ids] = torch.arange(new_num_reqs, device=logits_indices.device)
        
        # Get new request positions for each surviving token
        new_req_positions = old_to_new_req_pos[surviving_tokens_flat_req_ids]
        
        # Sort tokens by new request position to get decode tokens first, then prefill tokens
        sorted_indices = torch.argsort(new_req_positions)
        attn_metadata.swiftkv_inverse_sort_indices = torch.argsort(sorted_indices)
        reordered_logits_indices = logits_indices[sorted_indices]

        # 6. update other metadata fields
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[reordered_logits_indices]
        attn_metadata.num_actual_tokens = num_surviving_tokens
        attn_metadata.num_decodes = new_num_decodes
        attn_metadata.num_prefills = new_num_prefills
        attn_metadata.num_decode_tokens = new_num_decode_tokens
        attn_metadata.num_prefill_tokens = new_num_prefill_tokens
        attn_metadata.use_cascade = False

        # cascade attention fields
        attn_metadata.shared_qo_indptr = None
        attn_metadata.shared_kv_page_indptr = None
        attn_metadata.shared_kv_page_indices = None
        attn_metadata.shared_kv_last_page_len = None
        attn_metadata.cascade_wrapper = None

        # 7. re-plan the FlashInfer attention wrappers with new metadata
        impl = self.layers[-1].self_attn.attn.impl
        
        if attn_metadata.decode_wrapper and new_num_decodes > 0:
            attn_metadata.decode_wrapper.plan(
                attn_metadata.paged_kv_indptr[:new_num_decodes + 1],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[:new_num_decodes],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                pos_encoding_mode="NONE",
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
                )
        else:
            attn_metadata.decode_wrapper = None
        
        # Plan prefill wrapper if we have prefill requests
        if attn_metadata.prefill_wrapper and new_num_prefills > 0:
            # Prefill starts after decode requests
            prefill_start = new_num_decodes
            qo_indptr_prefill = attn_metadata.qo_indptr[prefill_start:] - attn_metadata.qo_indptr[prefill_start]
            attn_metadata.prefill_wrapper.plan(
                qo_indptr_prefill,
                attn_metadata.paged_kv_indptr[prefill_start:],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[prefill_start:],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.prefill_wrapper = None
        
        return reordered_logits_indices

    def swiftkv_select(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Write K/V into the caches and keep only the surviving tokens.

        Without attention metadata (graph capture or profiling) nothing is
        written: the preallocated decode-runner buffers are returned when the
        batch fits within ``cuda_graph_max_batch_size``, otherwise the inputs
        are returned untouched.

        With metadata, the K/V states of every SwiftKV layer whose cache is
        non-empty are written to that cache, the metadata is rewritten for
        the tokens listed in ``attn_metadata.swiftkv_logits_indices``, and
        the five tensors are gathered along dim 0 at those indices.

        Returns:
            ``(hidden_states, residual, positions, k_states, v_states)``
            restricted to the surviving tokens.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = get_attn_metadata_for_swiftkv()

        if attn_metadata is None:
            # Graph capture or profiling mode.
            num_tokens = hidden_states.shape[0]
            if num_tokens > self.cuda_graph_max_batch_size:
                return hidden_states, residual, positions, k_states, v_states
            # Return the preallocated buffers so cuda graph is captured
            # correctly.
            buffers = self.decode_runner.inputs
            padded_size = self.vllm_config.pad_for_cudagraph(num_tokens)
            return (buffers["hidden_states"][:padded_size],
                    buffers["residual"][:padded_size],
                    buffers["positions"][:padded_size],
                    buffers["k_states"][:padded_size],
                    buffers["v_states"][:padded_size])

        # The backend decides both the cache layout and the metadata fix-up.
        uses_flashinfer = (FLASHINFER_AVAILABLE
                           and isinstance(attn_metadata, FlashInferMetadata))
        first_swiftkv_layer = self.config.num_key_value_layers
        swiftkv_layers = self.layers[first_swiftkv_layer:]
        virtual_engine = forward_context.virtual_engine

        if self.use_custom_ops:
            key_caches: List[torch.Tensor] = []
            value_caches: List[torch.Tensor] = []
            k_scales: List[torch.Tensor] = []
            v_scales: List[torch.Tensor] = []
            last_attn = self.layers[-1].self_attn.attn
            num_heads = last_attn.num_kv_heads
            head_size = last_attn.head_size

            for layer in swiftkv_layers:
                attn = layer.self_attn.attn
                kv_cache = attn.kv_cache[virtual_engine]
                if not kv_cache.numel():
                    continue
                # different cache layouts
                if uses_flashinfer:
                    # FlashInfer: [num_blocks, 2, block_size, num_kv_heads,
                    # head_size]
                    key_caches.append(kv_cache[:, 0])
                    value_caches.append(kv_cache[:, 1])
                else:
                    # FlashAttention: [2, num_blocks, block_size,
                    # num_kv_heads, head_size]
                    key_caches.append(kv_cache[0])
                    value_caches.append(kv_cache[1])
                k_scales.append(attn._k_scale)
                v_scales.append(attn._v_scale)

            if key_caches:
                from arctic_inference.py_custom_ops import reshape_and_cache_flash_bulk

                # ``attn`` is the last layer visited by the loop above.
                reshape_and_cache_flash_bulk(
                    k_states, v_states, key_caches, value_caches,
                    attn_metadata.slot_mapping, attn.kv_cache_dtype, k_scales,
                    v_scales, num_heads, head_size)
        else:
            num_swiftkv_layers = (self.config.num_hidden_layers -
                                  first_swiftkv_layer)
            keys_per_layer = k_states.chunk(num_swiftkv_layers, dim=-1)
            values_per_layer = v_states.chunk(num_swiftkv_layers, dim=-1)

            for layer_pos, layer in enumerate(swiftkv_layers):
                attn = layer.self_attn.attn
                kv_cache = attn.kv_cache[virtual_engine]
                if not kv_cache.numel():
                    continue
                # FlashInfer: [num_blocks, 2, block_size, num_kv_heads,
                # head_size]
                # FlashAttention: [2, num_blocks, block_size, num_kv_heads,
                # head_size]
                kv_dim = 1 if uses_flashinfer else 0
                k_cache, v_cache = kv_cache.unbind(kv_dim)

                layer_keys = keys_per_layer[layer_pos].view(
                    -1, attn.num_kv_heads, attn.head_size)
                layer_values = values_per_layer[layer_pos].view(
                    -1, attn.num_kv_heads, attn.head_size)
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    layer_keys,
                    layer_values,
                    k_cache,
                    v_cache,
                    attn_metadata.slot_mapping,
                    attn.kv_cache_dtype,
                    attn._k_scale,
                    attn._v_scale,
                )

        logits_indices = attn_metadata.swiftkv_logits_indices
        num_surviving_tokens = logits_indices.numel()

        if uses_flashinfer:
            # Handle FlashInfer metadata
            final_logits_indices = self._fix_flashinfer_metadata(
                attn_metadata, logits_indices, num_surviving_tokens)
        else:
            # Handle FlashAttention metadata
            self._fix_flash_attention_metadata(
                attn_metadata, logits_indices, num_surviving_tokens)
            final_logits_indices = logits_indices

        def gather(buffer_name: str, tensor: torch.Tensor,
                   indices: torch.LongTensor) -> torch.Tensor:
            # If the batch size is smaller than the maximum batch size
            # for cuda graph, we can use the preallocated buffer.
            count = indices.numel()
            if not (0 < count <= self.cuda_graph_max_batch_size):
                return tensor.index_select(0, indices)
            buffer = self.decode_runner.inputs[buffer_name]
            torch.index_select(tensor, 0, indices, out=buffer[:count])
            return buffer[:self.vllm_config.pad_for_cudagraph(count)]

        return (gather("hidden_states", hidden_states, final_logits_indices),
                gather("residual", residual, final_logits_indices),
                gather("positions", positions, final_logits_indices),
                gather("k_states", k_states, final_logits_indices),
                gather("v_states", v_states, final_logits_indices))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run prefill, token selection and decode; return hidden states.

        The prefill runner's hidden states are returned, with the rows listed
        in ``attn_metadata.swiftkv_logits_indices`` overwritten by the decode
        runner's output. When no attention metadata is available, the prefill
        hidden states are returned as they are.
        """
        prefill_outputs = self.prefill_runner(input_ids, positions)
        hidden_states, residual, positions, k_states, v_states = (
            prefill_outputs)
        orig_hidden_states = hidden_states

        selected = self.swiftkv_select(hidden_states, residual, positions,
                                       k_states, v_states)

        with model_runner.set_shift_parallel_mode(True):
            hidden_states = self.decode_runner(*selected)

        attn_metadata = get_attn_metadata_for_swiftkv()
        if attn_metadata is None:
            return orig_hidden_states

        logits_indices = attn_metadata.swiftkv_logits_indices
        batch_size = logits_indices.numel()

        if (FLASHINFER_AVAILABLE
                and isinstance(attn_metadata, FlashInferMetadata)):
            # Undo the decode-first reordering applied during selection.
            restore_order = attn_metadata.swiftkv_inverse_sort_indices
            decoded = hidden_states[restore_order][:batch_size]
        else:
            decoded = hidden_states[:batch_size]
        orig_hidden_states[logits_indices] = decoded

        return orig_hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj.", ".q_proj.", "q"),
            (".qkv_proj.", ".k_proj.", "k"),
            (".qkv_proj.", ".v_proj.", "v"),
            (".gate_up_proj.", ".gate_proj.", 0),
            (".gate_up_proj.", ".up_proj.", 1),
            (".kv_proj_swiftkv.", ".k_proj_swiftkv.", "k"),
            (".kv_proj_swiftkv.", ".v_proj_swiftkv.", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class LlamaSwiftKVForCausalLM(nn.Module):
    """Causal-LM wrapper around a SwiftKV Llama backbone.

    Holds three submodules: the backbone built by ``_init_model``, a
    ``ParallelLMHead`` (optionally tied to the backbone's token embeddings),
    and a ``LogitsProcessor`` that turns hidden states into logits.
    """

    # Fused module name -> the separate projection names grouped under it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "kv_proj_swiftkv": ["k_proj_swiftkv", "v_proj_swiftkv"],
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config
                (``model_config.hf_config``) and of ``quant_config``.
            prefix: Name prefix under which the submodules are created.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        # Backbone goes first so it is the first registered submodule.
        model_prefix = maybe_prefix(prefix, "model")
        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=model_prefix)

        self.unpadded_vocab_size = hf_config.vocab_size

        head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=vllm_config.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.lm_head = head
        if hf_config.tie_word_embeddings:
            # Replace the head with the result of tying it to the
            # backbone's token embedding module.
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        # ``logit_scale`` is optional on the config; default is 1.0.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            hf_config.vocab_size,
            getattr(hf_config, "logit_scale", 1.0),
        )

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = ""):
        """Construct the ``LlamaSwiftKVModel`` backbone for this wrapper."""
        return LlamaSwiftKVModel(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup for ``input_ids`` to the backbone."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone on ``input_ids`` and ``positions``.

        ``intermediate_tensors`` and ``inputs_embeds`` are accepted only for
        signature compatibility; both are asserted to be ``None``.

        Returns:
            Whatever the backbone returns for ``(input_ids, positions)``.
        """
        assert intermediate_tensors is None
        assert inputs_embeds is None
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` through an ``AutoWeightsLoader``.

        When the config ties word embeddings, entries under ``"lm_head."``
        are skipped.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None
        return AutoWeightsLoader(self,
                                 skip_prefixes=skipped).load_weights(weights)