in src/runtime/relax_vm/paged_kv_cache.cc [823:1095]
void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const Optional<IntTuple>& opt_token_tree_parent_ptr) final {
    // Note: MLA does not support tree attention for now.
if (attn_kinds_[0] == AttnKind::kMLA) {
CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA";
}
CHECK_EQ(seq_ids.size(), append_lengths.size())
<< "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
<< append_lengths.size() << ") mismatch.";
cur_batch_size_ = seq_ids.size();
cur_seq_ids_ = seq_ids;
cur_append_lengths_ = append_lengths;
// - Collect sequence/block/page information for attention.
std::vector<Sequence*> sequences;
std::vector<int32_t> last_block_length_before_append;
is_decode_request_ = true;
sequences.reserve(cur_batch_size_);
last_block_length_before_append.reserve(cur_batch_size_);
k_ragged_rope_pos_offset_host_.clear();
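    // For each sequence, record its last block length before this append and the RoPE
    // position offset of its existing tokens. If the sequence still has an uncommitted
    // token tree from a previous round, those tree tokens are excluded from the offset.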
for (int i = 0; i < cur_batch_size_; ++i) {
auto it = seq_map_.find(seq_ids[i]);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
<< "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
last_block_length_before_append.push_back(
global_block_pool_[it->second.last_block_idx].seq_length);
int k_rope_offset = it->second.seq_length;
if (!it->second.accepted_indices_committed) {
int tree_size = static_cast<int>(it->second.token_tree_parent_ptr.size());
k_rope_offset -= tree_size;
}
k_ragged_rope_pos_offset_host_.push_back(k_rope_offset);
it->second.seq_length += append_lengths[i];
if (append_lengths[i] != 1) {
is_decode_request_ = false;
}
}
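    // Decompose the blocks of each sequence by their depth in the block tree. Blocks
    // that exceed the maximum supported depth are returned in `trailing_blocks` and
    // are concatenated to the deepest level when building the page table below.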
auto [block_ids_on_depths, trailing_blocks] =
GetBlockIdsOnDepth(sequences, global_block_pool_, cur_batch_size_);
num_depths_ =
std::min(static_cast<int>(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth);
ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);
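    // For each depth, coalesce consecutive blocks into chunks where possible, and
    // decide whether the decode kernel can serve that depth.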
std::vector<std::vector<std::pair<int32_t, int32_t>>> chunked_block_ids_arr;
chunked_block_ids_arr.reserve(num_depths_);
use_decode_kernel_.clear();
for (int d = 0; d < num_depths_; ++d) {
      // We force the blocks at the maximum depth not to coalesce, so that they can be
      // concatenated with the trailing blocks that exceed the maximum depth.
auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(
block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1,
cur_append_lengths_, global_block_pool_, is_decode_request_);
chunked_block_ids_arr.push_back(chunked_block_ids);
use_decode_kernel_.push_back(use_decode_kernel);
}
if (num_depths_ == kPagedKVCacheMaxBlockDepth) {
      // Since we force the blocks at the maximum depth not to coalesce, the number of
      // chunked blocks at the maximum depth must equal the current batch size.
CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
}
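    // `append_before_attn_` indicates whether the new KV values are appended into the
    // cache pages before the attention computation: this is the case when sliding
    // window is disabled and the deepest level can use the decode kernel.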
append_before_attn_ = !support_sliding_window_ && use_decode_kernel_.back();
if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
      // When the GQA group size is at least 4 and FlashInfer is enabled,
      // we always use the prefill kernel for better performance.
      // Note: for MLA we always use the prefill kernel, so the values in
      // `use_decode_kernel_` are ignored in that case.
std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
}
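    // If any sequence still carries an uncommitted token tree from a previous round,
    // the KV values must be appended before attention so that the attention over the
    // page table also covers the previous, uncommitted tree tokens.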
bool has_previous_tree =
std::any_of(sequences.begin(), sequences.end(),
[](const Sequence* sequence) { return !sequence->accepted_indices_committed; });
if (has_previous_tree) {
append_before_attn_ = true;
}
// - Check token tree validity and process the token tree.
if (opt_token_tree_parent_ptr.defined()) {
CHECK(!support_sliding_window_) << "Tree attention does not support sliding window.";
CHECK(rope_mode_ != RoPEMode::kInline) << "Tree attention does not support inline RoPE mode.";
ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value(), block_ids_on_depths,
trailing_blocks);
} else {
// The input batch does not form trees. So each sequence in the batch
// is required to have all past accepted tokens committed.
for (int i = 0; i < cur_batch_size_; ++i) {
Sequence* sequence = sequences[i];
CHECK(sequence->accepted_indices_committed)
<< "The input batch does not form a tree, in which case the sequences in the input "
"batch are expected to have their accepted tokens token tree nodes committed. "
"Please invoke CommitAcceptedTokenTreeNodes for sequence "
<< seq_ids[i];
sequence->is_chain = true;
sequence->token_tree_parent_ptr.clear();
sequence->token_tree_node_depths.clear();
}
std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
}
if (append_before_attn_) {
      // Right now we use different kernels depending on whether the maximum depth is 1.
      // When the maximum depth is 1, we create the auxiliary page table data
      // structures after appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
}
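    // Build the per-depth host-side auxiliary arrays consumed by the attention kernels:
    // `qo_indptr` and `page_indptr` are CSR-style offset arrays over the chunks at this
    // depth, `page_indices` flattens the page ids of every chunk, and `last_page_len`,
    // `sliding_window_offset`, `sink_size` and `k_rope_pos_offset` are per-chunk scalars.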
for (int d = 0; d < num_depths_; ++d) {
HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d];
HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d];
HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d];
HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d];
HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d];
HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d];
HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
qo_indptr_h.clear();
page_indptr_h.clear();
page_indices_h.clear();
last_page_len_h.clear();
sliding_window_offset_h.clear();
sink_size_h.clear();
k_rope_pos_offset_h.clear();
qo_indptr_h.push_back(0);
page_indptr_h.push_back(0);
for (int i = 0; i < static_cast<int>(chunked_block_ids_arr[d].size()); ++i) {
const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i];
qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
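        // A block id of -1 means this sequence has no block at this depth; emit an
        // empty entry so that the per-chunk arrays stay aligned with the batch.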
if (block_id == -1) {
page_indptr_h.push_back(page_indptr_h.back());
last_page_len_h.push_back(0);
sliding_window_offset_h.push_back(0);
sink_size_h.push_back(0);
k_rope_pos_offset_h.push_back(0);
} else {
if (d < kPagedKVCacheMaxBlockDepth - 1) {
// Blocks not at maximum depth
const Block& block = global_block_pool_[block_id];
page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
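            // The last page length folds the last written offset
            // (seq_length - sink_length + sliding_window_offset) into [1, page_size_],
            // or 0 for an empty block.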
last_page_len_h.push_back(
block.seq_length == 0
? 0
: (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
page_size_ +
1);
sliding_window_offset_h.push_back(block.sliding_window_offset);
sink_size_h.push_back(block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
} else {
// Blocks at maximum depth
const Block& block = global_block_pool_[block_id];
int32_t num_pages = static_cast<int32_t>(block.page_ids.size());
int32_t total_seq_length = static_cast<int32_t>(block.seq_length);
int32_t last_block_id = block_id;
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
            for (int32_t id : trailing_blocks[i]) {
              // Collect trailing blocks if available
              const Block& trailing_block = global_block_pool_[id];
              for (int32_t page_id : trailing_block.page_ids) {
                page_indices_h.push_back(page_id);
              }
              num_pages += trailing_block.page_ids.size();
              total_seq_length += trailing_block.seq_length;
              last_block_id = id;
            }
page_indptr_h.push_back(page_indptr_h.back() + num_pages);
const Block& last_block = global_block_pool_[last_block_id];
last_page_len_h.push_back(total_seq_length == 0
? 0
: (total_seq_length - last_block.sink_length +
last_block.sliding_window_offset - 1) %
page_size_ +
1);
sliding_window_offset_h.push_back(last_block.sliding_window_offset);
sink_size_h.push_back(last_block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
}
}
}
}
if (!append_before_attn_) {
      // Right now we use different kernels depending on whether the maximum depth is 1.
      // When the maximum depth is not 1, we create the auxiliary page table data
      // structures before appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
}
    // Map each token position in the input batch to its position in the
    // global KV cache. The mapping is used when appending k/v values.
q_rope_position_map_host_.clear();
append_position_map_host_.clear();
kv_transfer_remote_position_map_host_.clear();
kv_transfer_recver_id_host_.clear();
kv_transfer_page_to_page_local_position_map_host_.clear();
kv_transfer_page_to_page_remote_position_map_host_.clear();
kv_transfer_page_to_page_recver_id_host_.clear();
transfer_kv_ = false;
page_to_page_transfer_kv_ = false;
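    // For every newly appended token, compute (1) its RoPE query position, (2) its
    // destination position in the global KV cache (page_id * page_size_ + in-page
    // offset, or -1 when the slot is pinned by the attention sink), and (3) the remote
    // position and receiver id if its KV values need to be transferred.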
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = append_lengths[i];
const Block& block = global_block_pool_[sequences[i]->last_block_idx];
for (int64_t pos = 0; pos < append_length; ++pos) {
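        // For tree tokens, the RoPE position is determined by the node depth in the
        // token tree rather than by the linear offset within the appended chunk.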
if (sequences[i]->token_tree_node_depths.empty()) {
q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos);
} else {
int64_t offset_in_tree =
static_cast<int64_t>(sequences[i]->token_tree_parent_ptr.size()) - append_length;
ICHECK_GE(offset_in_tree, 0);
q_rope_position_map_host_.push_back(
k_ragged_rope_pos_offset_host_[i] +
sequences[i]->token_tree_node_depths[offset_in_tree + pos]);
}
int32_t pos_in_block = block.seq_length - append_length + pos;
if (last_block_length_before_append[i] + pos < block.sink_length) {
// The location to write is part of the attention sink.
int32_t offset_in_block = last_block_length_before_append[i] + pos;
append_position_map_host_.push_back(block.page_ids[offset_in_block / page_size_] *
page_size_ +
offset_in_block % page_size_);
} else if (pos_in_block < block.sink_length) {
          // The location to write is pinned by the attention sink before the append,
          // so we cannot write into it.
append_position_map_host_.push_back(-1);
} else {
// The location to write is in the sliding window.
int32_t offset_in_block = pos_in_block - block.sink_length + block.sliding_window_offset;
append_position_map_host_.push_back(block.page_ids[offset_in_block / page_size_] *
page_size_ +
offset_in_block % page_size_);
}
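        // KV transfer: positions before the transfer start are not sent (-1); the rest
        // are mapped to their remote positions and the receiver offset.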
int64_t pos_in_seq = sequences[i]->seq_length - append_length + pos;
int64_t seq_send_start = sequences[i]->kv_transfer_metadata.start;
if (pos_in_seq < seq_send_start) {
kv_transfer_remote_position_map_host_.push_back(-1);
kv_transfer_recver_id_host_.push_back(-1);
} else {
transfer_kv_ = true;
kv_transfer_remote_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.remote_position_map[pos_in_seq - seq_send_start]);
kv_transfer_recver_id_host_.push_back(
sequences[i]->kv_transfer_metadata.recver_pe_offset);
}
}
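      // Page-to-page KV transfer: if the sequence carries a local-to-remote position
      // map, schedule the transfer of those already-cached entries and clear the map
      // so they are only transferred once.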
if (!sequences[i]->kv_transfer_metadata.local_position_map.empty()) {
page_to_page_transfer_kv_ = true;
for (int pos = 0;
pos < static_cast<int>(sequences[i]->kv_transfer_metadata.local_position_map.size());
++pos) {
kv_transfer_page_to_page_local_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.local_position_map[pos]);
kv_transfer_page_to_page_remote_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.remote_position_map[pos]);
kv_transfer_page_to_page_recver_id_host_.push_back(
sequences[i]->kv_transfer_metadata.recver_pe_offset);
}
sequences[i]->kv_transfer_metadata.local_position_map.clear();
}
}
}