def _fix_flashinfer_metadata()

in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]


    def _fix_flashinfer_metadata(self, attn_metadata, logits_indices, num_surviving_tokens):
        """Rewrite FlashInfer attention metadata in place for the surviving tokens.

        Only the tokens selected by ``logits_indices`` are kept. The requests
        owning those tokens are regrouped so that requests with exactly one
        surviving token ("decode") come first, followed by requests with more
        than one surviving token ("prefill"). The query/KV index pointers,
        page indices, last-page lengths, slot mapping and the decode/prefill
        counters on ``attn_metadata`` are rebuilt accordingly, cascade
        attention is disabled, and the decode/prefill wrappers are re-planned
        (or set to ``None`` when they have nothing to process).

        Args:
            attn_metadata: Metadata object mutated in place. It must expose
                the attributes read below (``qo_indptr``, ``paged_kv_indptr``,
                ``paged_kv_indices``, ``paged_kv_last_page_len``,
                ``slot_mapping``, the wrappers and the head/page settings).
            logits_indices: 1-D integer tensor of indices into the original
                token batch identifying the surviving tokens.
                NOTE(review): assumed to be ordered so that tokens of one
                request appear in their original relative order -- confirm
                against the caller.
            num_surviving_tokens: Value stored as the new
                ``attn_metadata.num_actual_tokens``; presumably equal to
                ``logits_indices.numel()`` -- confirm against the caller.

        Returns:
            ``logits_indices`` permuted into the new decode-first token order.
            The permutation that restores the caller's order is stored in
            ``attn_metadata.swiftkv_inverse_sort_indices``.
        """
        device = logits_indices.device

        # 1. Map every original token to its request, then collect the
        #    surviving requests together with their surviving-token counts.
        original_qo_indptr = attn_metadata.qo_indptr
        original_num_tokens = attn_metadata.num_actual_tokens
        token_to_req_id = torch.searchsorted(
            original_qo_indptr,
            torch.arange(original_num_tokens, device=device),
            right=True) - 1
        surviving_tokens_flat_req_ids = token_to_req_id[logits_indices]
        surviving_req_ids, surviving_tokens_per_req = torch.unique(
            surviving_tokens_flat_req_ids, return_counts=True)
        new_num_reqs = surviving_req_ids.numel()

        # 2. Classify surviving requests: exactly 1 token -> decode,
        #    more than 1 token -> prefill.
        decode_mask = surviving_tokens_per_req == 1
        prefill_mask = surviving_tokens_per_req > 1

        decode_req_ids = surviving_req_ids[decode_mask]
        prefill_req_ids = surviving_req_ids[prefill_mask]

        new_num_decodes = decode_req_ids.numel()
        new_num_prefills = prefill_req_ids.numel()
        # Each decode request contributes exactly one token; every other
        # surviving token belongs to a prefill request. (Counting the True
        # entries of prefill_mask would yield the number of prefill
        # *requests*, not tokens.)
        new_num_decode_tokens = new_num_decodes
        new_num_prefill_tokens = logits_indices.numel() - new_num_decode_tokens

        # 3. Build qo_indptr for the surviving requests (decode first, then
        #    prefill) as an exclusive prefix sum of the per-request counts.
        reordered_req_ids = torch.cat([decode_req_ids, prefill_req_ids])
        reordered_tokens_per_req = torch.cat([
            surviving_tokens_per_req[decode_mask],
            surviving_tokens_per_req[prefill_mask],
        ])
        attn_metadata.qo_indptr = torch.nn.functional.pad(
            torch.cumsum(reordered_tokens_per_req, dim=0), (1, 0))

        # 4. Build the paged KV cache metadata for the surviving requests.
        original_kv_indptr = attn_metadata.paged_kv_indptr
        original_kv_indices = attn_metadata.paged_kv_indices
        reordered_num_pages_per_req = original_kv_indptr.diff()[reordered_req_ids]

        if new_num_reqs > 0:
            # Fetch all slice bounds with one transfer each instead of using
            # 0-dim tensors as slice bounds inside the loop.
            page_starts = original_kv_indptr[reordered_req_ids].tolist()
            page_ends = original_kv_indptr[reordered_req_ids + 1].tolist()
            attn_metadata.paged_kv_indices = torch.cat([
                original_kv_indices[start:end]
                for start, end in zip(page_starts, page_ends)
            ])
        else:
            # No requests survive SwiftKV selection.
            attn_metadata.paged_kv_indices = torch.empty(
                0,
                dtype=original_kv_indices.dtype,
                device=original_kv_indices.device)

        # paged_kv_indptr for the surviving requests (exclusive prefix sum).
        attn_metadata.paged_kv_indptr = torch.nn.functional.pad(
            torch.cumsum(reordered_num_pages_per_req, dim=0), (1, 0)).int()
        # Last-page lengths follow the same request order.
        attn_metadata.paged_kv_last_page_len = (
            attn_metadata.paged_kv_last_page_len[reordered_req_ids])

        # 5. Reorder logits_indices so decode tokens come first, then prefill
        #    tokens. The lookup table is sized from the original qo_indptr
        #    (one slot per possible request id) rather than from
        #    surviving_req_ids.max(), which raises when nothing survives.
        old_to_new_req_pos = torch.full(
            (original_qo_indptr.numel(),), -1, dtype=torch.long, device=device)
        old_to_new_req_pos[reordered_req_ids] = torch.arange(
            new_num_reqs, device=device)

        # New request position of each surviving token.
        new_req_positions = old_to_new_req_pos[surviving_tokens_flat_req_ids]

        # A stable sort is required: all tokens of one request share the same
        # key and must keep their relative order.
        sorted_indices = torch.argsort(new_req_positions, stable=True)
        attn_metadata.swiftkv_inverse_sort_indices = torch.argsort(sorted_indices)
        reordered_logits_indices = logits_indices[sorted_indices]

        # 6. Update the remaining metadata fields.
        attn_metadata.slot_mapping = attn_metadata.slot_mapping[reordered_logits_indices]
        attn_metadata.num_actual_tokens = num_surviving_tokens
        attn_metadata.num_decodes = new_num_decodes
        attn_metadata.num_prefills = new_num_prefills
        attn_metadata.num_decode_tokens = new_num_decode_tokens
        attn_metadata.num_prefill_tokens = new_num_prefill_tokens
        attn_metadata.use_cascade = False

        # Cascade attention is disabled above, so drop its fields.
        attn_metadata.shared_qo_indptr = None
        attn_metadata.shared_kv_page_indptr = None
        attn_metadata.shared_kv_page_indices = None
        attn_metadata.shared_kv_last_page_len = None
        attn_metadata.cascade_wrapper = None

        # 7. Re-plan the FlashInfer attention wrappers with the new metadata.
        #    Attention settings are read from the last layer's implementation.
        impl = self.layers[-1].self_attn.attn.impl

        if attn_metadata.decode_wrapper and new_num_decodes > 0:
            # Decode requests occupy the first new_num_decodes entries.
            attn_metadata.decode_wrapper.plan(
                attn_metadata.paged_kv_indptr[:new_num_decodes + 1],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[:new_num_decodes],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                pos_encoding_mode="NONE",
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.decode_wrapper = None

        if attn_metadata.prefill_wrapper and new_num_prefills > 0:
            # Prefill requests start right after the decode requests; the
            # query index pointer is rebased so it starts at zero.
            prefill_start = new_num_decodes
            qo_indptr_prefill = (attn_metadata.qo_indptr[prefill_start:]
                                 - attn_metadata.qo_indptr[prefill_start])
            attn_metadata.prefill_wrapper.plan(
                qo_indptr_prefill,
                attn_metadata.paged_kv_indptr[prefill_start:],
                attn_metadata.paged_kv_indices,
                attn_metadata.paged_kv_last_page_len[prefill_start:],
                attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads,
                attn_metadata.head_dim,
                attn_metadata.page_size,
                causal=True,
                sm_scale=impl.scale,
                window_left=impl.sliding_window[0],
                logits_soft_cap=impl.logits_soft_cap or 0.0,
                q_data_type=attn_metadata.q_data_type,
                kv_data_type=attn_metadata.data_type,
            )
        else:
            attn_metadata.prefill_wrapper = None

        return reordered_logits_indices