arctic_inference/vllm/swiftkv/llama_swiftkv.py (703 lines of code) (raw):
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Any, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
import vllm.distributed.parallel_state as parallel_state
from vllm.attention.backends.abstract import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import (LlamaAttention,
LlamaDecoderLayer,
LlamaMLP)
from vllm.model_executor.models.utils import (AutoWeightsLoader,
maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
# Add FlashInfer backend detection
try:
from vllm.v1.attention.backends.flashinfer import FlashInferMetadata
FLASHINFER_AVAILABLE = True
except ImportError:
FLASHINFER_AVAILABLE = False
FlashInferMetadata = None
import arctic_inference.vllm.model_runner as model_runner
from arctic_inference.common.swiftkv.configs import LlamaSwiftKVConfig
logger = init_logger(__name__)


def get_attn_metadata_for_swiftkv():
    """Return the single attention-metadata object shared by all layers.

    The forward context's ``attn_metadata`` is a mapping with one entry per
    attention layer. SwiftKV requires every entry to be the very same
    object, so one representative entry is returned.

    Returns:
        The shared metadata object, or ``None`` when the forward context
        carries no attention metadata (e.g. graph capture / profiling runs)
        or the mapping is empty.
    """
    fwd_ctx = get_forward_context()
    attn_metadata = fwd_ctx.attn_metadata
    # An empty mapping would otherwise make ``next(iter(...))`` raise
    # StopIteration; treat it the same as "no metadata".
    if not attn_metadata:
        return None
    meta = next(iter(attn_metadata.values()))
    assert all(m is meta for m in attn_metadata.values()), \
        "All attention metadata should be the same for LlamaSwiftKV."
    return meta
class LlamaSwiftKVAttention(LlamaAttention):
    """Attention for SwiftKV layers, whose keys/values are supplied externally.

    The regular ``LlamaAttention`` modules are still built by
    ``super().__init__``. Two extra projections are added:

    * ``q_proj_swiftkv`` -- the query projection used by ``forward``;
    * ``kv_proj_swiftkv`` -- a K/V-only projection (a QKV-parallel layer with
      zero query heads). ``forward`` does not call it; its output is passed
      back in through the ``k`` / ``v`` arguments.
    """

    def __init__(
        self,
        config: LlamaSwiftKVConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__(config=config,
                         hidden_size=hidden_size,
                         num_heads=num_heads,
                         num_kv_heads=num_kv_heads,
                         rope_theta=rope_theta,
                         rope_scaling=rope_scaling,
                         max_position_embeddings=max_position_embeddings,
                         quant_config=quant_config,
                         bias=bias,
                         bias_o_proj=bias_o_proj,
                         cache_config=cache_config,
                         prefix=prefix,
                         attn_type=attn_type)

        q_out_features = self.total_num_heads * self.head_dim
        self.q_proj_swiftkv = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=q_out_features,
            bias=bias,
            gather_output=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj_swiftkv",
        )
        # Zero query heads: this QKV-parallel layer produces only K and V.
        self.kv_proj_swiftkv = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj_swiftkv",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Attend with a freshly projected query over the given ``k`` / ``v``.

        Only the query goes through the rotary embedding here. A throwaway
        buffer shaped like ``k`` fills the key slot of ``rotary_emb``, so the
        caller's ``k`` is passed to attention as-is.
        """
        query, _ = self.q_proj_swiftkv(hidden_states)
        scratch_key = torch.empty_like(k)
        query, _ = self.rotary_emb(positions, query, scratch_key)
        # Same call for FlashAttention and FlashInfer: both backends expose
        # the self.attn(q, k, v) interface.
        attn_out = self.attn(query, k, v)
        projected, _ = self.o_proj(attn_out)
        return projected
class LlamaSwiftKVDecoderLayer(nn.Module):
    """Decoder layer whose attention consumes precomputed keys/values.

    Structure matches a Llama decoder layer (input norm -> attention ->
    post-attention norm -> MLP). The attention is a ``LlamaSwiftKVAttention``
    that receives ``k_states`` / ``v_states`` from the caller instead of
    projecting them itself.
    """

    def __init__(
        self,
        config: LlamaSwiftKVConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # The output projection's bias follows the same flag as the regular
        # Llama decoder layers used for the first layers of this model.
        # Previously SwiftKV layers always built a bias-free o_proj. Any
        # o_proj bias in the checkpoint was then silently dropped by the
        # "bias not in params_dict" skip in load_weights.
        bias_o_proj = attention_bias
        # Support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        self.self_attn = LlamaSwiftKVAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one SwiftKV decoder layer.

        Args:
            positions: token positions, forwarded to the attention.
            hidden_states: layer input.
            k_states: this layer's precomputed keys.
            v_states: this layer's precomputed values.
            residual: running residual, or ``None`` to start it from
                ``hidden_states``.

        Returns:
            ``(hidden_states, residual)``. The layer norms are called with
            the residual and return the updated residual alongside the
            normalized states.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            k=k_states,
            v=v_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
@support_torch_compile
class LlamaSwiftKVPrefillRunner(nn.Module):
    """Compiled first half of LlamaSwiftKV: KV layers + SwiftKV K/V projection.

    Runs the first ``num_key_value_layers`` decoder layers over every token.
    It then projects (and rotary-encodes) the keys and values of every
    remaining SwiftKV layer from one normalized hidden state.
    """

    def __init__(self, *, vllm_config: VllmConfig, model: "LlamaSwiftKVModel",
                 prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # Hold the owning model inside a list so nn.Module does not register
        # it as a child; the model itself owns this runner, so registering
        # it back would be recursive.
        self._model = [model]

    @property
    def model(self) -> "LlamaSwiftKVModel":
        return self._model[0]

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Run the KV layers and project K/V for all SwiftKV layers.

        Returns:
            ``(hidden_states, residual, positions, k_states, v_states)``.
            ``k_states`` / ``v_states`` concatenate the per-layer keys and
            values of the SwiftKV layers along the last dimension. With
            ulysses sequence parallelism (outside shift-parallel mode), all
            outputs are all-gathered across SP ranks first.
        """
        cfg = self.config
        mdl = self.model
        num_kv_layers = cfg.num_key_value_layers

        hidden_states = mdl.get_input_embeddings(input_ids)
        residual = None
        for kv_layer in mdl.layers[:num_kv_layers]:
            hidden_states, residual = kv_layer(positions, hidden_states,
                                               residual)

        if (parallel_state._SP.world_size > 1
                and not model_runner.is_shift_parallel_mode()):
            # All-gather across ulysses sequence parallel ranks
            sp_group = parallel_state._SP
            hidden_states = sp_group.all_gather(hidden_states, dim=0)
            residual = sp_group.all_gather(residual, dim=0)
            positions = sp_group.all_gather(positions, dim=0)

        # Swap in the SP_TP group as the tensor-parallel group while the
        # SwiftKV projections run; both globals are restored below.
        saved_sp_tp_mode = model_runner.SP_TP_MODE
        saved_tp_group = parallel_state.get_tp_group()
        model_runner.SP_TP_MODE = True
        parallel_state._TP = parallel_state._SP_TP

        # KV projection of all the remaining layers, from one shared input.
        normed = mdl.norm_swiftkv(hidden_states + residual)
        rope = mdl.layers[0].self_attn.rotary_emb
        # Scratch query for rotary_emb; only the rotated keys are kept.
        scratch_q = torch.empty_like(hidden_states)
        keys = []
        values = []
        for swiftkv_layer in mdl.layers[num_kv_layers:]:
            kv, _ = swiftkv_layer.self_attn.kv_proj_swiftkv(normed)
            key, value = kv.chunk(2, dim=-1)
            _, key = rope(positions, scratch_q, key)
            keys.append(key)
            values.append(value)
        k_states = torch.cat(keys, dim=-1)
        v_states = torch.cat(values, dim=-1)

        model_runner.SP_TP_MODE = saved_sp_tp_mode
        parallel_state._TP = saved_tp_group
        return hidden_states, residual, positions, k_states, v_states
@support_torch_compile
class LlamaSwiftKVDecodeRunner(nn.Module):
    """Compiled second half of LlamaSwiftKV: SwiftKV layers + final norm."""

    def __init__(self, *, vllm_config: VllmConfig, model: "LlamaSwiftKVModel",
                 prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # Hold the owning model inside a list so nn.Module does not register
        # it as a child; the model itself owns this runner, so registering
        # it back would be recursive.
        self._model = [model]

    @property
    def model(self) -> "LlamaSwiftKVModel":
        return self._model[0]

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run every SwiftKV layer, then apply the model's final norm.

        ``k_states`` / ``v_states`` hold all SwiftKV layers' keys/values
        concatenated on the last dimension. They are split with
        ``torch.chunk`` into ``num_hidden_layers - num_key_value_layers``
        pieces, one per layer, in layer order.
        """
        # Compiler hint: both tensors share their leading size, so a single
        # symbolic dimension gets inferred for them.
        torch._check(v_states.shape[0] == k_states.shape[0])
        first_swiftkv = self.config.num_key_value_layers
        num_swiftkv = self.config.num_hidden_layers - first_swiftkv
        k_chunks = torch.chunk(k_states, num_swiftkv, dim=-1)
        v_chunks = torch.chunk(v_states, num_swiftkv, dim=-1)

        swiftkv_layers = self.model.layers[first_swiftkv:]
        for i, swiftkv_layer in enumerate(swiftkv_layers):
            hidden_states, residual = swiftkv_layer(
                positions,
                hidden_states,
                k_chunks[i],
                v_chunks[i],
                residual,
            )
        normed, _ = self.model.norm(hidden_states, residual)
        return normed
class LlamaSwiftKVModel(nn.Module):
    """Llama backbone that runs its upper layers on a token subset (SwiftKV).

    Layers ``[0, config.num_key_value_layers)`` are regular
    ``LlamaDecoderLayer``s and run over every scheduled token inside
    ``prefill_runner``. That runner also projects the keys/values of all
    remaining layers from a single normalized hidden state.

    ``swiftkv_select`` then:

    * writes those keys/values into the KV cache;
    * keeps only the tokens in ``attn_metadata.swiftkv_logits_indices``.

    The remaining ``LlamaSwiftKVDecoderLayer``s run on that subset inside
    ``decode_runner``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=self.quant_config,
        )
        self.layers = torch.nn.ModuleList([
            LlamaDecoderLayer(config=config,
                              cache_config=vllm_config.cache_config,
                              quant_config=vllm_config.quant_config,
                              prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.num_key_value_layers)
        ])
        with model_runner.set_shift_parallel_mode(True):
            self.layers.extend([
                LlamaSwiftKVDecoderLayer(config=config,
                                         cache_config=vllm_config.cache_config,
                                         quant_config=vllm_config.quant_config,
                                         prefix=f"{prefix}.layers.{idx}")
                for idx in range(config.num_key_value_layers,
                                 config.num_hidden_layers)
            ])
        self.norm_swiftkv = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Tag SwiftKV parameters so load_weights loads them in shift-parallel
        # mode.
        for param in self.layers[config.num_key_value_layers:].parameters():
            param.shift_parallel_mode = True

        self._init_prefill_runner(vllm_config)
        self._init_decode_runner(vllm_config)

        from arctic_inference.py_custom_ops import try_load_torch_library
        self.use_custom_ops = bool(try_load_torch_library())

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    @staticmethod
    def _is_flashinfer(attn_metadata) -> bool:
        """Whether ``attn_metadata`` comes from vLLM's FlashInfer backend."""
        return FLASHINFER_AVAILABLE and isinstance(attn_metadata,
                                                   FlashInferMetadata)

    @staticmethod
    def _copy_compilation_config(vllm_config: VllmConfig) -> None:
        """Replace ``vllm_config.compilation_config`` with a shallow copy.

        The ``inductor_compile_config`` dict is copied as well. This is done
        right before each compiled runner is built. Presumably the goal is
        that each runner works on its own compilation config object
        (NOTE(review): confirm against ``support_torch_compile``).
        """
        vllm_config.compilation_config = copy.copy(
            vllm_config.compilation_config)
        vllm_config.compilation_config.inductor_compile_config = (
            vllm_config.compilation_config.inductor_compile_config.copy())

    def _init_prefill_runner(self, vllm_config: VllmConfig):
        self._copy_compilation_config(vllm_config)
        self.prefill_runner = LlamaSwiftKVPrefillRunner(
            vllm_config=vllm_config, model=self)

    def _init_decode_runner(self, vllm_config: VllmConfig):
        """Build the decode runner and, with CUDA graphs, its input buffers."""
        self._copy_compilation_config(vllm_config)
        self.decode_runner = LlamaSwiftKVDecodeRunner(
            vllm_config=vllm_config, model=self)

        config = vllm_config.model_config.hf_config
        capture_sizes = vllm_config.compilation_config.cudagraph_capture_sizes
        if not capture_sizes:
            self.cuda_graph_max_batch_size = 0
            return

        max_batch = max(capture_sizes)
        self.cuda_graph_max_batch_size = max_batch
        last_attn = self.layers[-1].self_attn.attn
        num_swiftkv_layers = (config.num_hidden_layers -
                              config.num_key_value_layers)
        kv_size = num_swiftkv_layers * last_attn.num_kv_heads * last_attn.head_size
        # Persistent decode-runner inputs. swiftkv_select copies selected
        # tokens into these, so captured CUDA graphs always see the same
        # storage. No dtype is given, so torch's current default dtype is
        # used.
        self.decode_runner.inputs = {
            "hidden_states": torch.empty(max_batch, config.hidden_size,
                                         device="cuda"),
            "residual": torch.empty(max_batch, config.hidden_size,
                                    device="cuda"),
            "positions": torch.empty(max_batch, dtype=torch.long,
                                     device="cuda"),
            "k_states": torch.empty(max_batch, kv_size, device="cuda"),
            "v_states": torch.empty(max_batch, kv_size, device="cuda"),
        }

    def _fix_flash_attention_metadata(self, attn_metadata, logits_indices,
                                      num_surviving_tokens):
        """Restrict FlashAttention metadata, in place, to surviving tokens."""
        attn_metadata.num_actual_tokens = num_surviving_tokens
        # For each original request start, count the surviving tokens before
        # it; that count is the request's start in the compacted batch.
        # (searchsorted assumes logits_indices is sorted ascending.)
        attn_metadata.query_start_loc = torch.searchsorted(
            logits_indices, attn_metadata.query_start_loc, out_int32=True)
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[
            logits_indices]
        # TODO: Make cascade attention work with SwiftKV
        attn_metadata.use_cascade = False
        attn_metadata.cu_prefix_query_lens = None
        attn_metadata.prefix_kv_lens = None
        attn_metadata.suffix_kv_lens = None
        attn_metadata.prefix_scheduler_metadata = None
        # NOTE(review): other per-batch fields (e.g. max_query_len, or a
        # precomputed non-prefix scheduler metadata if the backend has one)
        # still describe the uncompacted batch -- confirm this is tolerated.

    def _fix_flashinfer_metadata(self, attn_metadata, logits_indices,
                                 num_surviving_tokens):
        """Restrict FlashInfer metadata to surviving tokens and re-plan.

        Surviving requests are regrouped decode-first, then prefill. A
        request is a decode when exactly one of its tokens survives. The
        decode/prefill wrappers are then re-planned for the compacted batch.

        Returns:
            ``logits_indices`` permuted into the new decode-first token
            order. The inverse permutation is stored on
            ``attn_metadata.swiftkv_inverse_sort_indices``.
        """
        device = logits_indices.device
        old_qo_indptr = attn_metadata.qo_indptr

        # 1. Map each surviving token to its original request and count the
        #    surviving tokens per request.
        original_num_tokens = attn_metadata.num_actual_tokens
        token_to_req_id = torch.searchsorted(
            old_qo_indptr,
            torch.arange(original_num_tokens, device=device),
            right=True) - 1
        surviving_tokens_flat_req_ids = token_to_req_id[logits_indices]
        surviving_req_ids, surviving_tokens_per_req = torch.unique(
            surviving_tokens_flat_req_ids, return_counts=True)
        new_num_reqs = surviving_req_ids.numel()

        # 2. Classify surviving requests: decode = exactly one token.
        #    When the batch has no decode wrapper (it was all prefills),
        #    single-token survivors must stay on the prefill wrapper.
        #    Otherwise they would be counted as decodes while step 7 sets
        #    decode_wrapper to None, and no wrapper would cover them.
        if attn_metadata.decode_wrapper:
            decode_mask = surviving_tokens_per_req == 1
        else:
            decode_mask = torch.zeros_like(surviving_tokens_per_req,
                                           dtype=torch.bool)
        prefill_mask = ~decode_mask
        decode_req_ids = surviving_req_ids[decode_mask]
        prefill_req_ids = surviving_req_ids[prefill_mask]
        decode_tokens_per_req = surviving_tokens_per_req[decode_mask]
        prefill_tokens_per_req = surviving_tokens_per_req[prefill_mask]
        new_num_decodes = decode_req_ids.numel()
        new_num_prefills = prefill_req_ids.numel()
        new_num_decode_tokens = int(decode_tokens_per_req.sum().item())
        # Number of prefill *tokens* (previously this counted requests).
        new_num_prefill_tokens = int(prefill_tokens_per_req.sum().item())

        # 3. qo_indptr over the reordered requests (decodes first), kept in
        #    the original index dtype.
        reordered_req_ids = torch.cat([decode_req_ids, prefill_req_ids])
        reordered_tokens_per_req = torch.cat(
            [decode_tokens_per_req, prefill_tokens_per_req])
        attn_metadata.qo_indptr = torch.nn.functional.pad(
            torch.cumsum(reordered_tokens_per_req, dim=0),
            (1, 0)).to(old_qo_indptr.dtype)

        # 4. Paged KV cache metadata for the reordered requests.
        old_paged_kv_indptr = attn_metadata.paged_kv_indptr
        old_paged_kv_indices = attn_metadata.paged_kv_indices
        reordered_num_pages_per_req = old_paged_kv_indptr.diff()[
            reordered_req_ids]
        if new_num_reqs > 0:
            # Move the page ranges to the host once instead of syncing on
            # every per-request scalar index.
            starts = old_paged_kv_indptr[reordered_req_ids].tolist()
            ends = old_paged_kv_indptr[reordered_req_ids + 1].tolist()
            attn_metadata.paged_kv_indices = torch.cat([
                old_paged_kv_indices[start:end]
                for start, end in zip(starts, ends)
            ])
        else:
            # no requests survive SwiftKV selection
            attn_metadata.paged_kv_indices = torch.empty(
                0,
                dtype=old_paged_kv_indices.dtype,
                device=old_paged_kv_indices.device)
        attn_metadata.paged_kv_indptr = torch.nn.functional.pad(
            torch.cumsum(reordered_num_pages_per_req, dim=0), (1, 0)).int()
        attn_metadata.paged_kv_last_page_len = (
            attn_metadata.paged_kv_last_page_len[reordered_req_ids])

        # 5. Reorder the surviving tokens to follow the new request order.
        #    The lookup table is sized from the original qo_indptr, so it
        #    also works when nothing survives. The sort must be stable so
        #    that tokens of one request keep their original order, which
        #    causal attention over that request depends on.
        old_to_new_req_pos = torch.full((old_qo_indptr.numel(), ),
                                        -1,
                                        dtype=torch.long,
                                        device=device)
        old_to_new_req_pos[reordered_req_ids] = torch.arange(new_num_reqs,
                                                             device=device)
        new_req_positions = old_to_new_req_pos[surviving_tokens_flat_req_ids]
        sorted_indices = torch.argsort(new_req_positions, stable=True)
        attn_metadata.swiftkv_inverse_sort_indices = torch.argsort(
            sorted_indices)
        reordered_logits_indices = logits_indices[sorted_indices]

        # 6. Remaining per-batch fields.
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[
            reordered_logits_indices]
        attn_metadata.num_actual_tokens = num_surviving_tokens
        attn_metadata.num_decodes = new_num_decodes
        attn_metadata.num_prefills = new_num_prefills
        attn_metadata.num_decode_tokens = new_num_decode_tokens
        attn_metadata.num_prefill_tokens = new_num_prefill_tokens
        attn_metadata.use_cascade = False
        # cascade attention fields
        # NOTE(review): if this batch had been planned for cascade attention,
        # it may carry neither a decode nor a prefill wrapper to fall back on.
        # Confirm that cascade is disabled upstream for SwiftKV + FlashInfer.
        attn_metadata.shared_qo_indptr = None
        attn_metadata.shared_kv_page_indptr = None
        attn_metadata.shared_kv_page_indices = None
        attn_metadata.shared_kv_last_page_len = None
        attn_metadata.cascade_wrapper = None

        # 7. Re-plan the FlashInfer wrappers for the compacted batch.
        impl = self.layers[-1].self_attn.attn.impl
        if attn_metadata.decode_wrapper and new_num_decodes > 0:
            attn_metadata.decode_wrapper.plan(
                attn_metadata.paged_kv_indptr[:new_num_decodes + 1],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[:new_num_decodes],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                pos_encoding_mode="NONE",
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.decode_wrapper = None

        if attn_metadata.prefill_wrapper and new_num_prefills > 0:
            # Prefill requests come after the decode requests; make the
            # query offsets relative to the first prefill token.
            prefill_start = new_num_decodes
            qo_indptr_prefill = (attn_metadata.qo_indptr[prefill_start:] -
                                 attn_metadata.qo_indptr[prefill_start])
            attn_metadata.prefill_wrapper.plan(
                qo_indptr_prefill,
                attn_metadata.paged_kv_indptr[prefill_start:],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[prefill_start:],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.prefill_wrapper = None

        return reordered_logits_indices

    def swiftkv_select(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        k_states: torch.Tensor,
        v_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Cache the SwiftKV keys/values and select the surviving tokens.

        The keys/values of every token are written into each SwiftKV
        layer's KV cache. The attention metadata is then compacted to the
        tokens in ``swiftkv_logits_indices``, and the five inputs are
        returned restricted to those tokens. When the selection fits the
        CUDA-graph buffers, the results are views of the persistent
        ``decode_runner.inputs`` buffers, padded to the CUDA-graph batch
        size.

        Without attention metadata (graph capture / profiling) no cache
        write happens. In that case the persistent buffers are returned when
        the batch fits them, otherwise the inputs unchanged.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = get_attn_metadata_for_swiftkv()
        if attn_metadata is None:
            # Graph capture or profiling mode.
            batch_size = hidden_states.shape[0]
            if (self.cuda_graph_max_batch_size > 0
                    and batch_size <= self.cuda_graph_max_batch_size):
                # Return the preallocated buffers so cuda graph is captured
                # correctly.
                inputs = self.decode_runner.inputs
                padded_size = self.vllm_config.pad_for_cudagraph(batch_size)
                return (inputs["hidden_states"][:padded_size],
                        inputs["residual"][:padded_size],
                        inputs["positions"][:padded_size],
                        inputs["k_states"][:padded_size],
                        inputs["v_states"][:padded_size])
            return hidden_states, residual, positions, k_states, v_states

        is_flashinfer = self._is_flashinfer(attn_metadata)
        swiftkv_layers = self.layers[self.config.num_key_value_layers:]
        virtual_engine = forward_context.virtual_engine

        if self.use_custom_ops:
            key_caches: List[torch.Tensor] = []
            value_caches: List[torch.Tensor] = []
            k_scales: List[torch.Tensor] = []
            v_scales: List[torch.Tensor] = []
            num_heads = self.layers[-1].self_attn.attn.num_kv_heads
            head_size = self.layers[-1].self_attn.attn.head_size
            for layer in swiftkv_layers:
                attn = layer.self_attn.attn
                kv_cache = attn.kv_cache[virtual_engine]
                if not kv_cache.numel():
                    continue
                # different cache layouts
                if is_flashinfer:
                    # FlashInfer: [num_blocks, 2, block_size, num_kv_heads, head_size]
                    key_caches.append(kv_cache[:, 0])
                    value_caches.append(kv_cache[:, 1])
                else:
                    # FlashAttention: [2, num_blocks, block_size, num_kv_heads, head_size]
                    key_caches.append(kv_cache[0])
                    value_caches.append(kv_cache[1])
                # Scales are collected only alongside a cache, so the four
                # lists stay index-aligned.
                k_scales.append(attn._k_scale)
                v_scales.append(attn._v_scale)
            if key_caches:
                from arctic_inference.py_custom_ops import (
                    reshape_and_cache_flash_bulk)
                reshape_and_cache_flash_bulk(
                    k_states, v_states, key_caches, value_caches,
                    attn_metadata.slot_mapping, attn.kv_cache_dtype, k_scales,
                    v_scales, num_heads, head_size)
        else:
            num_layers = (self.config.num_hidden_layers -
                          self.config.num_key_value_layers)
            k_split = k_states.chunk(num_layers, dim=-1)
            v_split = v_states.chunk(num_layers, dim=-1)
            for idx, layer in enumerate(swiftkv_layers):
                attn = layer.self_attn.attn
                kv_cache = attn.kv_cache[virtual_engine]
                if not kv_cache.numel():
                    continue
                if is_flashinfer:
                    # FlashInfer: [num_blocks, 2, block_size, num_kv_heads, head_size]
                    k_cache, v_cache = kv_cache.unbind(1)
                else:
                    # FlashAttention: [2, num_blocks, block_size, num_kv_heads, head_size]
                    k_cache, v_cache = kv_cache.unbind(0)
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    k_split[idx].view(-1, attn.num_kv_heads, attn.head_size),
                    v_split[idx].view(-1, attn.num_kv_heads, attn.head_size),
                    k_cache,
                    v_cache,
                    attn_metadata.slot_mapping,
                    attn.kv_cache_dtype,
                    attn._k_scale,
                    attn._v_scale,
                )

        # Compact the metadata only after the cache writes above, which use
        # the full slot_mapping.
        logits_indices = attn_metadata.swiftkv_logits_indices
        num_surviving_tokens = logits_indices.numel()
        if is_flashinfer:
            final_logits_indices = self._fix_flashinfer_metadata(
                attn_metadata, logits_indices, num_surviving_tokens)
        else:
            self._fix_flash_attention_metadata(attn_metadata, logits_indices,
                                               num_surviving_tokens)
            final_logits_indices = logits_indices

        def index_fn(buffer_name: str, tensor: torch.Tensor,
                     indices: torch.LongTensor) -> torch.Tensor:
            # If the batch size is smaller than the maximum batch size
            # for cuda graph, we can use the preallocated buffer.
            batch_size = indices.numel()
            if 0 < batch_size <= self.cuda_graph_max_batch_size:
                buffer = self.decode_runner.inputs[buffer_name]
                torch.index_select(tensor, 0, indices, out=buffer[:batch_size])
                padded_size = self.vllm_config.pad_for_cudagraph(batch_size)
                return buffer[:padded_size]
            return tensor.index_select(0, indices)

        return (index_fn("hidden_states", hidden_states, final_logits_indices),
                index_fn("residual", residual, final_logits_indices),
                index_fn("positions", positions, final_logits_indices),
                index_fn("k_states", k_states, final_logits_indices),
                index_fn("v_states", v_states, final_logits_indices))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Full SwiftKV forward: prefill runner -> select -> decode runner.

        The prefill runner's hidden states are the output tensor. The rows
        at ``swiftkv_logits_indices`` are overwritten in place with the
        decode runner's results. For FlashInfer, the decode-first
        reordering is undone first.
        """
        hidden_states, residual, positions, k_states, v_states = (
            self.prefill_runner(input_ids, positions))

        orig_hidden_states = hidden_states
        hidden_states, residual, positions, k_states, v_states = (
            self.swiftkv_select(hidden_states, residual, positions, k_states,
                                v_states))

        with model_runner.set_shift_parallel_mode(True):
            hidden_states = self.decode_runner(
                hidden_states,
                residual,
                positions,
                k_states,
                v_states,
            )

        attn_metadata = get_attn_metadata_for_swiftkv()
        if attn_metadata is not None:
            logits_indices = attn_metadata.swiftkv_logits_indices
            batch_size = logits_indices.numel()
            if self._is_flashinfer(attn_metadata):
                inverse_sort_indices = (
                    attn_metadata.swiftkv_inverse_sort_indices)
                orig_hidden_states[logits_indices] = (
                    hidden_states[inverse_sort_indices][:batch_size])
            else:
                orig_hidden_states[logits_indices] = hidden_states[:batch_size]
        return orig_hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v, gate/up and swiftkv k/v.

        Each parameter is loaded under ``set_shift_parallel_mode`` with that
        parameter's ``shift_parallel_mode`` attribute (``None`` when
        absent).

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj.", ".q_proj.", "q"),
            (".qkv_proj.", ".k_proj.", "k"),
            (".qkv_proj.", ".v_proj.", "v"),
            (".gate_up_proj.", ".gate_proj.", 0),
            (".gate_up_proj.", ".up_proj.", 1),
            (".kv_proj_swiftkv.", ".k_proj_swiftkv.", "k"),
            (".kv_proj_swiftkv.", ".v_proj_swiftkv.", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                use_shift_mode = getattr(param, "shift_parallel_mode", None)
                with model_runner.set_shift_parallel_mode(use_shift_mode):
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class LlamaSwiftKVForCausalLM(nn.Module):
    """Causal-LM entry point: ``LlamaSwiftKVModel`` backbone plus LM head."""

    # Checkpoint shards fused into one parameter when weights are loaded.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "kv_proj_swiftkv": ["k_proj_swiftkv", "v_proj_swiftkv"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        vocab_size = hf_config.vocab_size
        self.unpadded_vocab_size = vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Share the input embedding matrix with the output head.
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        scale = getattr(hf_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                vocab_size, scale)

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Construct the backbone (kept separate so subclasses can swap it)."""
        return LlamaSwiftKVModel(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Return the backbone's hidden states for ``input_ids``.

        ``intermediate_tensors`` and ``inputs_embeds`` are not supported and
        must be ``None``.
        """
        assert intermediate_tensors is None and inputs_embeds is None
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via AutoWeightsLoader.

        ``lm_head.`` weights are skipped when the embeddings are tied.
        """
        skip = ["lm_head."] if self.config.tie_word_embeddings else None
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip)
        return weights_loader.load_weights(weights)