in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]
def _fix_flashinfer_metadata(self, attn_metadata, logits_indices, num_surviving_tokens):
    """Restrict FlashInfer attention metadata to the SwiftKV surviving tokens.

    Rewrites ``attn_metadata`` in place so it describes only the tokens
    selected by ``logits_indices``. Surviving requests are reordered so that
    decode requests (exactly one surviving token) precede prefill requests
    (more than one surviving token); tokens are permuted to match, keeping
    their relative order within each request. Cascade attention is disabled
    and the FlashInfer decode/prefill wrappers are re-planned for the new
    layout (a wrapper is set to ``None`` when it is absent or has no
    requests to serve).

    Args:
        attn_metadata: FlashInfer attention metadata describing the full
            batch. Mutated in place.
        logits_indices: 1-D integer tensor of flat token indices into the
            original batch that survive SwiftKV selection.
        num_surviving_tokens: Stored as the new ``num_actual_tokens``;
            presumably equal to ``logits_indices.numel()`` -- TODO confirm
            against the caller.

    Returns:
        ``logits_indices`` permuted into the new decode-first token order.
        The inverse permutation is stored in
        ``attn_metadata.swiftkv_inverse_sort_indices``, so indexing the
        returned tensor with it restores the input order.
    """
    device = logits_indices.device
    # Capture facts about the original layout before any field is replaced.
    original_num_tokens = attn_metadata.num_actual_tokens
    original_num_reqs = attn_metadata.qo_indptr.numel() - 1
    qo_indptr_dtype = attn_metadata.qo_indptr.dtype
    paged_kv_indptr_dtype = attn_metadata.paged_kv_indptr.dtype

    # 1. Map every token to its request, then count survivors per request.
    #    torch.unique returns the surviving request ids in sorted order.
    token_to_req_id = torch.searchsorted(
        attn_metadata.qo_indptr,
        torch.arange(original_num_tokens, dtype=qo_indptr_dtype, device=device),
        right=True) - 1
    surviving_tokens_flat_req_ids = token_to_req_id[logits_indices]
    surviving_req_ids, surviving_tokens_per_req = torch.unique(
        surviving_tokens_flat_req_ids, return_counts=True)
    new_num_reqs = surviving_req_ids.numel()

    # 2. Classify surviving requests: decode = exactly 1 surviving token,
    #    prefill = more than 1 surviving token.
    decode_mask = surviving_tokens_per_req == 1
    prefill_mask = surviving_tokens_per_req > 1
    decode_req_ids = surviving_req_ids[decode_mask]
    prefill_req_ids = surviving_req_ids[prefill_mask]
    decode_tokens_per_req = surviving_tokens_per_req[decode_mask]
    prefill_tokens_per_req = surviving_tokens_per_req[prefill_mask]
    new_num_decodes = decode_req_ids.numel()
    new_num_prefills = prefill_req_ids.numel()
    new_num_decode_tokens = decode_tokens_per_req.sum().item()
    # Sum per-request token counts: each prefill request keeps several
    # tokens, so counting requests would under-count prefill tokens.
    new_num_prefill_tokens = prefill_tokens_per_req.sum().item()

    # 3. Reorder requests decode-first and rebuild qo_indptr. cumsum promotes
    #    integers to int64, so cast back to the original index dtype.
    reordered_req_ids = torch.cat([decode_req_ids, prefill_req_ids])
    reordered_tokens_per_req = torch.cat([decode_tokens_per_req, prefill_tokens_per_req])
    attn_metadata.qo_indptr = torch.nn.functional.pad(
        torch.cumsum(reordered_tokens_per_req, dim=0), (1, 0)).to(qo_indptr_dtype)

    # 4. Rebuild the paged-KV layout for the reordered requests.
    old_paged_kv_indptr = attn_metadata.paged_kv_indptr
    reordered_num_pages_per_req = old_paged_kv_indptr.diff()[reordered_req_ids]
    page_indices_start = old_paged_kv_indptr[reordered_req_ids]
    new_paged_kv_indptr = torch.nn.functional.pad(
        torch.cumsum(reordered_num_pages_per_req, dim=0), (1, 0)).to(paged_kv_indptr_dtype)
    # Gather all surviving page lists in one indexing op: output slot k of
    # reordered request i reads old slot
    # page_indices_start[i] + (k - new_paged_kv_indptr[i]).
    # Avoids a per-request Python loop (one host sync per slice) and
    # naturally yields an empty tensor when no request survives.
    page_offsets = torch.repeat_interleave(
        page_indices_start - new_paged_kv_indptr[:-1],
        reordered_num_pages_per_req)
    gather_idx = torch.arange(page_offsets.numel(), device=page_offsets.device) + page_offsets
    attn_metadata.paged_kv_indices = attn_metadata.paged_kv_indices[gather_idx]
    attn_metadata.paged_kv_indptr = new_paged_kv_indptr
    attn_metadata.paged_kv_last_page_len = attn_metadata.paged_kv_last_page_len[reordered_req_ids]

    # 5. Permute surviving tokens to follow the new request order.
    #    Size the map by the original request count so an empty selection
    #    does not call .max() on an empty tensor.
    old_to_new_req_pos = torch.full((original_num_reqs,), -1, dtype=torch.long, device=device)
    old_to_new_req_pos[reordered_req_ids] = torch.arange(new_num_reqs, device=device)
    new_req_positions = old_to_new_req_pos[surviving_tokens_flat_req_ids]
    # stable=True keeps tokens of the same request in their original relative
    # order; they share a sort key and an unstable sort may permute them.
    sorted_indices = torch.argsort(new_req_positions, stable=True)
    attn_metadata.swiftkv_inverse_sort_indices = torch.argsort(sorted_indices)
    reordered_logits_indices = logits_indices[sorted_indices]

    # 6. Update remaining per-batch fields.
    attn_metadata.slot_mapping = attn_metadata.slot_mapping[reordered_logits_indices]
    attn_metadata.num_actual_tokens = num_surviving_tokens
    attn_metadata.num_decodes = new_num_decodes
    attn_metadata.num_prefills = new_num_prefills
    attn_metadata.num_decode_tokens = new_num_decode_tokens
    attn_metadata.num_prefill_tokens = new_num_prefill_tokens
    # Cascade attention is not rebuilt for the reduced batch; disable it.
    attn_metadata.use_cascade = False
    attn_metadata.shared_qo_indptr = None
    attn_metadata.shared_kv_page_indptr = None
    attn_metadata.shared_kv_page_indices = None
    attn_metadata.shared_kv_last_page_len = None
    attn_metadata.cascade_wrapper = None

    # 7. Re-plan the FlashInfer wrappers using the last layer's attention
    #    parameters (scale, sliding window, soft cap).
    impl = self.layers[-1].self_attn.attn.impl
    if attn_metadata.decode_wrapper and new_num_decodes > 0:
        # Decode requests occupy the first new_num_decodes slots.
        attn_metadata.decode_wrapper.plan(
            attn_metadata.paged_kv_indptr[:new_num_decodes + 1],
            attn_metadata.paged_kv_indices,
            attn_metadata.paged_kv_last_page_len[:new_num_decodes],
            attn_metadata.num_qo_heads,
            attn_metadata.num_kv_heads,
            attn_metadata.head_dim,
            attn_metadata.page_size,
            pos_encoding_mode="NONE",
            sm_scale=impl.scale,
            window_left=impl.sliding_window[0],
            logits_soft_cap=impl.logits_soft_cap or 0.0,
            q_data_type=attn_metadata.q_data_type,
            kv_data_type=attn_metadata.data_type,
        )
    else:
        attn_metadata.decode_wrapper = None
    if attn_metadata.prefill_wrapper and new_num_prefills > 0:
        # Prefill requests follow the decodes. The query indptr is rebased to
        # start at 0; the page indptr keeps absolute offsets into the full
        # paged_kv_indices tensor, which is passed whole.
        prefill_start = new_num_decodes
        qo_indptr_prefill = (attn_metadata.qo_indptr[prefill_start:]
                             - attn_metadata.qo_indptr[prefill_start])
        attn_metadata.prefill_wrapper.plan(
            qo_indptr_prefill,
            attn_metadata.paged_kv_indptr[prefill_start:],
            attn_metadata.paged_kv_indices,
            attn_metadata.paged_kv_last_page_len[prefill_start:],
            attn_metadata.num_qo_heads,
            attn_metadata.num_kv_heads,
            attn_metadata.head_dim,
            attn_metadata.page_size,
            causal=True,
            sm_scale=impl.scale,
            window_left=impl.sliding_window[0],
            logits_soft_cap=impl.logits_soft_cap or 0.0,
            q_data_type=attn_metadata.q_data_type,
            kv_data_type=attn_metadata.data_type,
        )
    else:
        attn_metadata.prefill_wrapper = None
    return reordered_logits_indices