arctic_inference/vllm/swiftkv/llama_swiftkv.py (703 lines of code) (raw):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Any, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

import vllm.distributed.parallel_state as parallel_state
from vllm.attention.backends.abstract import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import (LlamaAttention,
                                              LlamaDecoderLayer, LlamaMLP)
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

# Add FlashInfer backend detection
try:
    from vllm.v1.attention.backends.flashinfer import FlashInferMetadata
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False
    FlashInferMetadata = None

import arctic_inference.vllm.model_runner as model_runner
from arctic_inference.common.swiftkv.configs import LlamaSwiftKVConfig

logger = init_logger(__name__)


def _is_flashinfer_metadata(attn_metadata) -> bool:
    """Return True when *attn_metadata* is a FlashInfer metadata object.

    Always False when the FlashInfer backend could not be imported.
    """
    return FLASHINFER_AVAILABLE and isinstance(attn_metadata,
                                               FlashInferMetadata)


def get_attn_metadata_for_swiftkv():
    """Return the attention metadata shared by every layer, or ``None``.

    ``None`` is returned when the current forward context carries no
    attention metadata. Otherwise the first value of the per-layer metadata
    mapping is returned, after asserting that all values are the very same
    object.
    """
    fwd_ctx = get_forward_context()
    if fwd_ctx.attn_metadata is None:
        return None
    meta = next(iter(fwd_ctx.attn_metadata.values()))
    assert all(m is meta for m in fwd_ctx.attn_metadata.values()), \
        "All attention metadata should be the same for LlamaSwiftKV."
    return meta


class LlamaSwiftKVAttention(LlamaAttention):
    """Llama attention extended with the SwiftKV projections.

    Adds ``q_proj_swiftkv`` (query projection used by :meth:`forward`) and
    ``kv_proj_swiftkv`` (key/value projection, applied by the prefill runner
    rather than in :meth:`forward`).
    """

    def __init__(
        self,
        config: LlamaSwiftKVConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type)

        self.q_proj_swiftkv = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=self.total_num_heads * self.head_dim,
            bias=bias,
            gather_output=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj_swiftkv",
        )
        # total_num_heads=0: this layer only produces the K and V shards.
        self.kv_proj_swiftkv = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj_swiftkv",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Attend with a freshly projected query over precomputed ``k``/``v``.

        Only the query is projected and rotated here; ``k`` and ``v`` are
        passed to the attention op exactly as given.
        """
        q, _ = self.q_proj_swiftkv(hidden_states)
        # The rotary op takes a key argument as well; a scratch tensor is
        # supplied and its rotated result is discarded.
        q, _ = self.rotary_emb(positions, q, torch.empty_like(k))
        # The attention call works the same for both FlashAttention and
        # FlashInfer as they both use the same interface: self.attn(q, k, v)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaSwiftKVDecoderLayer(nn.Module):
    """Decoder layer whose attention consumes externally supplied K/V."""

    def __init__(
        self,
        config: LlamaSwiftKVConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaSwiftKVAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention (with the given K/V) followed by the MLP.

        Returns the MLP output and the updated residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            k=k_states,
            v=v_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class LlamaSwiftKVPrefillRunner(nn.Module):
    """Compiled runner for the first ``num_key_value_layers`` layers.

    It also computes the key/value states for all remaining layers from the
    hidden state produced by those first layers.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 model: "LlamaSwiftKVModel",
                 prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self._model = [model]  # Box it to avoid recursive registration

    @property
    def model(self) -> "LlamaSwiftKVModel":
        return self._model[0]

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Return ``(hidden_states, residual, positions, k_states, v_states)``.

        ``k_states`` / ``v_states`` hold the per-layer keys / values of every
        remaining layer concatenated along the last dimension.
        """
        hidden_states = self.model.get_input_embeddings(input_ids)
        residual = None
        prefill_layers = self.model.layers[:self.config.num_key_value_layers]
        for layer in prefill_layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        sp_size = parallel_state._SP.world_size
        if sp_size > 1 and not model_runner.is_shift_parallel_mode():
            # All-gather across ulysses sequence parallel ranks
            hidden_states = parallel_state._SP.all_gather(hidden_states,
                                                          dim=0)
            residual = parallel_state._SP.all_gather(residual, dim=0)
            positions = parallel_state._SP.all_gather(positions, dim=0)

        # Temporarily switch the module-level parallel state for the KV
        # projections below; the previous values are restored before
        # returning.
        old_mode = model_runner.SP_TP_MODE
        old_tp_group = parallel_state.get_tp_group()
        model_runner.SP_TP_MODE = True
        parallel_state._TP = parallel_state._SP_TP

        # KV projection of all the remaining layers
        swiftkv_hidden_states = (
            self.model.norm_swiftkv(hidden_states + residual))
        k_states = []
        v_states = []
        rotary_emb = self.model.layers[0].self_attn.rotary_emb
        q = torch.empty_like(hidden_states)  # Just temporary buffer
        for layer in self.model.layers[self.config.num_key_value_layers:]:
            kv, _ = layer.self_attn.kv_proj_swiftkv(swiftkv_hidden_states)
            k, v = kv.chunk(2, dim=-1)
            _, k = rotary_emb(positions, q, k)
            k_states.append(k)
            v_states.append(v)
        k_states = torch.cat(k_states, dim=-1)
        v_states = torch.cat(v_states, dim=-1)

        model_runner.SP_TP_MODE = old_mode
        parallel_state._TP = old_tp_group

        return hidden_states, residual, positions, k_states, v_states


@support_torch_compile
class LlamaSwiftKVDecodeRunner(nn.Module):
    """Compiled runner for the layers after ``num_key_value_layers``."""

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 model: "LlamaSwiftKVModel",
                 prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self._model = [model]  # Box it to avoid recursive registration

    @property
    def model(self) -> "LlamaSwiftKVModel":
        return self._model[0]

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the remaining layers and the final norm.

        ``k_states`` / ``v_states`` are split along the last dimension into
        one equal chunk per remaining layer.
        """
        # This is a hint for the compiler that v_states and k_states have
        # the same shape so that a single symbolic shape is inferred.
        torch._check(v_states.shape[0] == k_states.shape[0])
        num_layers = (self.config.num_hidden_layers -
                      self.config.num_key_value_layers)
        k_split = torch.chunk(k_states, num_layers, dim=-1)
        v_split = torch.chunk(v_states, num_layers, dim=-1)
        for idx, layer in enumerate(
                self.model.layers[self.config.num_key_value_layers:]):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                k_split[idx],
                v_split[idx],
                residual,
            )
        hidden_states, _ = self.model.norm(hidden_states, residual)
        return hidden_states


class LlamaSwiftKVModel(nn.Module):
    """Llama backbone split into a prefill runner and a decode runner.

    The first ``config.num_key_value_layers`` layers are regular
    ``LlamaDecoderLayer`` instances; the rest are
    ``LlamaSwiftKVDecoderLayer`` instances that only process the tokens
    selected by :meth:`swiftkv_select`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config

        config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=self.quant_config,
        )
        self.layers = torch.nn.ModuleList([
            LlamaDecoderLayer(config=config,
                              cache_config=vllm_config.cache_config,
                              quant_config=vllm_config.quant_config,
                              prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.num_key_value_layers)
        ])
        with model_runner.set_shift_parallel_mode(True):
            self.layers.extend([
                LlamaSwiftKVDecoderLayer(
                    config=config,
                    cache_config=vllm_config.cache_config,
                    quant_config=vllm_config.quant_config,
                    prefix=f"{prefix}.layers.{idx}")
                for idx in range(config.num_key_value_layers,
                                 config.num_hidden_layers)
            ])
        self.norm_swiftkv = RMSNorm(config.hidden_size,
                                    eps=config.rms_norm_eps)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Tag the SwiftKV-layer parameters so load_weights() loads them under
        # the same shift-parallel mode they were created in.
        for param in self.layers[config.num_key_value_layers:].parameters():
            param.shift_parallel_mode = True

        self._init_prefill_runner(vllm_config)
        self._init_decode_runner(vllm_config)

        from arctic_inference.py_custom_ops import try_load_torch_library
        self.use_custom_ops = bool(try_load_torch_library())

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the token embedding table."""
        return self.embed_tokens(input_ids)

    @staticmethod
    def _fork_compilation_config(vllm_config: VllmConfig) -> None:
        """Replace the compilation config on *vllm_config* with a copy.

        Both the config object (shallow copy) and its
        ``inductor_compile_config`` are copied, so the runner constructed
        next does not share them with the previously constructed one.
        """
        vllm_config.compilation_config = copy.copy(
            vllm_config.compilation_config)
        vllm_config.compilation_config.inductor_compile_config = (
            vllm_config.compilation_config.inductor_compile_config.copy())

    def _init_prefill_runner(self, vllm_config: VllmConfig):
        """Create ``self.prefill_runner`` with its own compilation config."""
        self._fork_compilation_config(vllm_config)
        self.prefill_runner = LlamaSwiftKVPrefillRunner(
            vllm_config=vllm_config, model=self)

    def _init_decode_runner(self, vllm_config: VllmConfig):
        """Create ``self.decode_runner`` and its CUDA-graph input buffers.

        The buffers are only allocated when CUDA-graph capture sizes are
        configured; otherwise ``cuda_graph_max_batch_size`` is set to 0.
        """
        self._fork_compilation_config(vllm_config)
        self.decode_runner = LlamaSwiftKVDecodeRunner(
            vllm_config=vllm_config, model=self)

        config = vllm_config.model_config.hf_config
        if vllm_config.compilation_config.cudagraph_capture_sizes:
            self.cuda_graph_max_batch_size = max(
                vllm_config.compilation_config.cudagraph_capture_sizes)
            max_bs = self.cuda_graph_max_batch_size
            num_heads = self.layers[-1].self_attn.attn.num_kv_heads
            head_size = self.layers[-1].self_attn.attn.head_size
            num_kv = config.num_hidden_layers - config.num_key_value_layers
            kv_size = num_kv * num_heads * head_size
            self.decode_runner.inputs = {
                "hidden_states":
                torch.empty(max_bs, config.hidden_size, device="cuda"),
                "residual":
                torch.empty(max_bs, config.hidden_size, device="cuda"),
                "positions":
                torch.empty(max_bs, dtype=torch.long, device="cuda"),
                "k_states":
                torch.empty(max_bs, kv_size, device="cuda"),
                "v_states":
                torch.empty(max_bs, kv_size, device="cuda"),
            }
        else:
            self.cuda_graph_max_batch_size = 0

    def _fix_flash_attention_metadata(self, attn_metadata, logits_indices,
                                      num_surviving_tokens):
        """Shrink FlashAttention metadata in place to the surviving tokens."""
        # FlashAttention path
        attn_metadata.num_actual_tokens = num_surviving_tokens
        attn_metadata.query_start_loc = torch.searchsorted(
            logits_indices, attn_metadata.query_start_loc, out_int32=True)
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[
            logits_indices]

        # TODO: Make cascade attention work with SwiftKV
        attn_metadata.use_cascade = False
        attn_metadata.cu_prefix_query_lens = None
        attn_metadata.prefix_kv_lens = None
        attn_metadata.suffix_kv_lens = None
        attn_metadata.prefix_scheduler_metadata = None

    def _fix_flashinfer_metadata(self, attn_metadata, logits_indices,
                                 num_surviving_tokens):
        """Rebuild FlashInfer metadata in place for the surviving tokens.

        Surviving requests are reordered so that requests with exactly one
        surviving token (decodes) come before requests with more than one
        (prefills), and the FlashInfer wrappers are re-planned accordingly.
        The permutation that undoes the token reordering is stored on
        ``attn_metadata.swiftkv_inverse_sort_indices``.

        Returns:
            ``logits_indices`` permuted into the new (decode-first) order.
        """
        device = logits_indices.device

        # 1. get survived requests and get their token counts.
        original_num_tokens = attn_metadata.num_actual_tokens
        token_to_req_id = torch.searchsorted(
            attn_metadata.qo_indptr,
            torch.arange(original_num_tokens, device=device),
            right=True) - 1
        surviving_tokens_flat_req_ids = token_to_req_id[logits_indices]
        surviving_req_ids, surviving_tokens_per_req = torch.unique(
            surviving_tokens_flat_req_ids, return_counts=True)
        new_num_reqs = surviving_req_ids.numel()

        # 2. classify surviving requests as decode vs prefill
        #    decode: exactly 1 token, prefill: > 1 token
        decode_mask = surviving_tokens_per_req == 1
        prefill_mask = surviving_tokens_per_req > 1
        decode_req_ids = surviving_req_ids[decode_mask]
        prefill_req_ids = surviving_req_ids[prefill_mask]

        new_num_decodes = decode_req_ids.numel()
        new_num_prefills = prefill_req_ids.numel()
        # Every decode request contributes exactly one token.
        new_num_decode_tokens = new_num_decodes
        # Every surviving token belongs to either a decode or a prefill
        # request, so the prefill token count is the remainder. (Counting
        # the True entries of prefill_mask would give the number of prefill
        # *requests*, not tokens.)
        new_num_prefill_tokens = (logits_indices.numel() -
                                  new_num_decode_tokens)

        # 3. build qo_indptr for surviving requests (decode first, then
        #    prefill)
        reordered_req_ids = torch.cat([decode_req_ids, prefill_req_ids])
        reordered_tokens_per_req = torch.cat([
            surviving_tokens_per_req[decode_mask],
            surviving_tokens_per_req[prefill_mask],
        ])
        attn_metadata.qo_indptr = torch.nn.functional.pad(
            torch.cumsum(reordered_tokens_per_req, dim=0), (1, 0))

        # 4. build paged KV cache metadata for surviving requests
        original_num_pages_per_req = attn_metadata.paged_kv_indptr.diff()
        reordered_num_pages_per_req = (
            original_num_pages_per_req[reordered_req_ids])
        page_indices_start = attn_metadata.paged_kv_indptr[reordered_req_ids]
        page_indices_end = attn_metadata.paged_kv_indptr[reordered_req_ids +
                                                         1]

        if new_num_reqs > 0:
            # Gather each surviving request's page indices. The bounds are
            # converted to Python ints once, up front, rather than indexing
            # with 0-dim tensors inside the loop.
            starts = page_indices_start.tolist()
            ends = page_indices_end.tolist()
            attn_metadata.paged_kv_indices = torch.cat([
                attn_metadata.paged_kv_indices[start:end]
                for start, end in zip(starts, ends)
            ])
        else:
            # no requests survive SwiftKV selection
            attn_metadata.paged_kv_indices = torch.empty(
                0,
                dtype=attn_metadata.paged_kv_indices.dtype,
                device=attn_metadata.paged_kv_indices.device)

        # build paged_kv_indptr for surviving requests
        attn_metadata.paged_kv_indptr = torch.nn.functional.pad(
            torch.cumsum(reordered_num_pages_per_req, dim=0), (1, 0)).int()

        # update last page lengths for surviving requests
        attn_metadata.paged_kv_last_page_len = (
            attn_metadata.paged_kv_last_page_len[reordered_req_ids])

        # 5. create reordered logits_indices (decode tokens first, then
        #    prefill tokens)
        # Map original req_ids to new positions. max() raises on an empty
        # tensor, so the no-survivor case uses an empty lookup table.
        lookup_size = (int(surviving_req_ids.max().item()) +
                       1 if new_num_reqs > 0 else 0)
        old_to_new_req_pos = torch.full((lookup_size, ),
                                        -1,
                                        dtype=torch.long,
                                        device=device)
        old_to_new_req_pos[reordered_req_ids] = torch.arange(new_num_reqs,
                                                             device=device)

        # Get new request positions for each surviving token
        new_req_positions = old_to_new_req_pos[surviving_tokens_flat_req_ids]

        # Sort tokens by new request position to get decode tokens first,
        # then prefill tokens
        sorted_indices = torch.argsort(new_req_positions)
        attn_metadata.swiftkv_inverse_sort_indices = torch.argsort(
            sorted_indices)
        reordered_logits_indices = logits_indices[sorted_indices]

        # 6. update other metadata fields
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[
            reordered_logits_indices]
        attn_metadata.num_actual_tokens = num_surviving_tokens
        attn_metadata.num_decodes = new_num_decodes
        attn_metadata.num_prefills = new_num_prefills
        attn_metadata.num_decode_tokens = new_num_decode_tokens
        attn_metadata.num_prefill_tokens = new_num_prefill_tokens

        attn_metadata.use_cascade = False
        # cascade attention fields
        attn_metadata.shared_qo_indptr = None
        attn_metadata.shared_kv_page_indptr = None
        attn_metadata.shared_kv_page_indices = None
        attn_metadata.shared_kv_last_page_len = None
        attn_metadata.cascade_wrapper = None

        # 7. re-plan the FlashInfer attention wrappers with new metadata
        impl = self.layers[-1].self_attn.attn.impl
        if attn_metadata.decode_wrapper and new_num_decodes > 0:
            attn_metadata.decode_wrapper.plan(
                attn_metadata.paged_kv_indptr[:new_num_decodes + 1],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[:new_num_decodes],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                pos_encoding_mode="NONE",
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.decode_wrapper = None

        # Plan prefill wrapper if we have prefill requests
        if attn_metadata.prefill_wrapper and new_num_prefills > 0:
            # Prefill starts after decode requests
            prefill_start = new_num_decodes
            qo_indptr_prefill = (attn_metadata.qo_indptr[prefill_start:] -
                                 attn_metadata.qo_indptr[prefill_start])
            attn_metadata.prefill_wrapper.plan(
                qo_indptr_prefill,
                attn_metadata.paged_kv_indptr[prefill_start:],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[prefill_start:],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.prefill_wrapper = None

        return reordered_logits_indices

    def swiftkv_select(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Write K/V to the caches and keep only the tokens that need logits.

        With attention metadata present, the K/V states are written into the
        KV cache of every SwiftKV layer, the attention metadata is rewritten
        in place for the rows listed in
        ``attn_metadata.swiftkv_logits_indices``, and those rows of the five
        inputs are returned. When the selection fits, the rows are copied
        into the preallocated decode-runner buffers and a slice padded to a
        CUDA-graph size is returned.

        Without attention metadata (graph capture or profiling) the inputs
        are returned unchanged, or the preallocated buffers are returned when
        the batch fits in them.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = get_attn_metadata_for_swiftkv()
        if attn_metadata is None:
            # Graph capture or profiling mode.
            if hidden_states.shape[0] <= self.cuda_graph_max_batch_size:
                # Return the preallocated buffers so cuda graph is captured
                # correctly.
                inputs = self.decode_runner.inputs
                batch_size = hidden_states.shape[0]
                padded_size = self.vllm_config.pad_for_cudagraph(batch_size)
                return (inputs["hidden_states"][:padded_size],
                        inputs["residual"][:padded_size],
                        inputs["positions"][:padded_size],
                        inputs["k_states"][:padded_size],
                        inputs["v_states"][:padded_size])
            return hidden_states, residual, positions, k_states, v_states

        is_flashinfer = _is_flashinfer_metadata(attn_metadata)
        swiftkv_layers = self.layers[self.config.num_key_value_layers:]

        if self.use_custom_ops:
            key_caches: List[torch.Tensor] = []
            value_caches: List[torch.Tensor] = []
            k_scales: List[torch.Tensor] = []
            v_scales: List[torch.Tensor] = []
            num_heads = self.layers[-1].self_attn.attn.num_kv_heads
            head_size = self.layers[-1].self_attn.attn.head_size
            for layer in swiftkv_layers:
                attn = layer.self_attn.attn
                kv_cache = attn.kv_cache[forward_context.virtual_engine]
                if kv_cache.numel():
                    # different cache layouts
                    if is_flashinfer:
                        # FlashInfer: [num_blocks, 2, block_size,
                        #              num_kv_heads, head_size]
                        key_caches.append(kv_cache[:, 0])
                        value_caches.append(kv_cache[:, 1])
                    else:
                        # FlashAttention: [2, num_blocks, block_size,
                        #                  num_kv_heads, head_size]
                        key_caches.append(kv_cache[0])
                        value_caches.append(kv_cache[1])
                    k_scales.append(attn._k_scale)
                    v_scales.append(attn._v_scale)

            if len(key_caches) > 0:
                from arctic_inference.py_custom_ops import (
                    reshape_and_cache_flash_bulk)
                # `attn` is the last layer visited by the loop above; its
                # kv_cache_dtype is used for the whole bulk write.
                reshape_and_cache_flash_bulk(
                    k_states, v_states, key_caches, value_caches,
                    attn_metadata.slot_mapping, attn.kv_cache_dtype,
                    k_scales, v_scales, num_heads, head_size)
        else:
            num_layers = (self.config.num_hidden_layers -
                          self.config.num_key_value_layers)
            k_split = k_states.chunk(num_layers, dim=-1)
            v_split = v_states.chunk(num_layers, dim=-1)
            for idx, layer in enumerate(swiftkv_layers):
                attn = layer.self_attn.attn
                kv_cache = attn.kv_cache[forward_context.virtual_engine]
                if kv_cache.numel():
                    if is_flashinfer:
                        # FlashInfer: [num_blocks, 2, block_size,
                        #              num_kv_heads, head_size]
                        k_cache, v_cache = kv_cache.unbind(1)
                    else:
                        # FlashAttention: [2, num_blocks, block_size,
                        #                  num_kv_heads, head_size]
                        k_cache, v_cache = kv_cache.unbind(0)
                    torch.ops._C_cache_ops.reshape_and_cache_flash(
                        k_split[idx].view(-1, attn.num_kv_heads,
                                          attn.head_size),
                        v_split[idx].view(-1, attn.num_kv_heads,
                                          attn.head_size),
                        k_cache,
                        v_cache,
                        attn_metadata.slot_mapping,
                        attn.kv_cache_dtype,
                        attn._k_scale,
                        attn._v_scale,
                    )

        logits_indices = attn_metadata.swiftkv_logits_indices
        num_surviving_tokens = logits_indices.numel()

        if is_flashinfer:
            # Handle FlashInfer metadata
            final_logits_indices = self._fix_flashinfer_metadata(
                attn_metadata, logits_indices, num_surviving_tokens)
        else:
            # Handle FlashAttention metadata
            self._fix_flash_attention_metadata(attn_metadata, logits_indices,
                                               num_surviving_tokens)
            final_logits_indices = logits_indices

        def index_fn(buffer_name: str, tensor: torch.Tensor,
                     indices: torch.LongTensor) -> torch.Tensor:
            # If the batch size is smaller than the maximum batch size
            # for cuda graph, we can use the preallocated buffer.
            batch_size = indices.numel()
            if 0 < batch_size <= self.cuda_graph_max_batch_size:
                buffer = self.decode_runner.inputs[buffer_name]
                torch.index_select(tensor, 0, indices,
                                   out=buffer[:batch_size])
                padded_size = self.vllm_config.pad_for_cudagraph(batch_size)
                return buffer[:padded_size]
            return tensor.index_select(0, indices)

        return (
            index_fn("hidden_states", hidden_states, final_logits_indices),
            index_fn("residual", residual, final_logits_indices),
            index_fn("positions", positions, final_logits_indices),
            index_fn("k_states", k_states, final_logits_indices),
            index_fn("v_states", v_states, final_logits_indices),
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run prefill, token selection and decode.

        The decode-runner outputs are scattered back into the prefill hidden
        states at ``swiftkv_logits_indices``; the resulting tensor is
        returned.
        """
        hidden_states, residual, positions, k_states, v_states = (
            self.prefill_runner(input_ids, positions))

        orig_hidden_states = hidden_states
        hidden_states, residual, positions, k_states, v_states = (
            self.swiftkv_select(
                hidden_states,
                residual,
                positions,
                k_states,
                v_states,
            ))

        with model_runner.set_shift_parallel_mode(True):
            hidden_states = self.decode_runner(
                hidden_states,
                residual,
                positions,
                k_states,
                v_states,
            )

        attn_metadata = get_attn_metadata_for_swiftkv()
        if attn_metadata is not None:
            logits_indices = attn_metadata.swiftkv_logits_indices
            batch_size = logits_indices.numel()
            if _is_flashinfer_metadata(attn_metadata):
                # Undo the decode-first token reordering applied by
                # _fix_flashinfer_metadata before scattering back.
                inverse_sort_indices = (
                    attn_metadata.swiftkv_inverse_sort_indices)
                orig_hidden_states[logits_indices] = (
                    hidden_states[inverse_sort_indices][:batch_size])
            else:
                orig_hidden_states[logits_indices] = (
                    hidden_states[:batch_size])

        return orig_hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the loaded parameter names.

        Separate q/k/v, gate/up and SwiftKV k/v checkpoint tensors are loaded
        into their fused parameters. Each parameter is loaded under the
        shift-parallel mode recorded on it (``shift_parallel_mode``
        attribute, ``None`` when absent).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj.", ".q_proj.", "q"),
            (".qkv_proj.", ".k_proj.", "k"),
            (".qkv_proj.", ".v_proj.", "v"),
            (".gate_up_proj.", ".gate_proj.", 0),
            (".gate_up_proj.", ".up_proj.", 1),
            (".kv_proj_swiftkv.", ".k_proj_swiftkv.", "k"),
            (".kv_proj_swiftkv.", ".v_proj_swiftkv.", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors
                # in the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0
                                 else loaded_weight[0])
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class LlamaSwiftKVForCausalLM(nn.Module):
    """Causal-LM wrapper: SwiftKV backbone, LM head and logits processor."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "kv_proj_swiftkv": ["k_proj_swiftkv", "v_proj_swiftkv"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(
                self.model.embed_tokens)

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                logit_scale)

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone model."""
        return LlamaSwiftKVModel(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone on ``input_ids`` / ``positions``.

        ``intermediate_tensors`` and ``inputs_embeds`` must both be ``None``.
        """
        assert intermediate_tensors is None and inputs_embeds is None
        model_output = self.model(input_ids, positions)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader``.

        ``lm_head.`` entries are skipped when word embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)