in src/runtime/relax_vm/paged_kv_cache.cc [273:493]
explicit PagedAttentionKVCacheObj(
int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,
int64_t layer_id_end_offset, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim,
int64_t v_head_dim, std::vector<AttnKind> attn_kinds, int64_t reserved_num_seqs,
int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
RoPEMode rope_mode, double rotary_scale, double rotary_theta,
Optional<NDArray> rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device,
Optional<PackedFunc> f_transpose_append_mha, Optional<PackedFunc> f_transpose_append_mla,
PackedFunc f_compact_copy, std::unique_ptr<RaggedPrefillFunc> f_attention_prefill_ragged,
std::unique_ptr<PagedPrefillFunc> f_attention_prefill,
std::unique_ptr<PagedDecodeFunc> f_attention_decode,
std::unique_ptr<PagedPrefillFunc> f_attention_prefill_sliding_window,
std::unique_ptr<PagedDecodeFunc> f_attention_decode_sliding_window,
std::unique_ptr<PagedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask_paged_kv,
std::unique_ptr<RaggedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask,
std::unique_ptr<PagedPrefillFunc> f_mla_prefill, Array<PackedFunc> f_merge_inplace,
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, PackedFunc f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
layer_id_begin_offset_(layer_id_begin_offset),
layer_id_end_offset_(layer_id_end_offset),
num_qo_heads_(num_qo_heads),
num_kv_heads_(num_kv_heads),
qk_head_dim_(qk_head_dim),
v_head_dim_(v_head_dim),
num_total_pages_(num_total_pages),
prefill_chunk_size_(prefill_chunk_size),
support_sliding_window_(support_sliding_window),
attn_kinds_(std::move(attn_kinds)),
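// With sliding window enabled, cached KV positions shift as the window
// slides, so RoPE cannot be pre-applied at append time; fall back to
// applying RoPE inline inside the attention kernels.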
rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline
: rope_mode),
rotary_scale_(rotary_scale),
rotary_theta_(rotary_theta),
rope_ext_factors_(std::move(rope_ext_factors)),
kv_dtype_(DataType(dtype)),
f_transpose_append_mha_(std::move(f_transpose_append_mha)),
f_transpose_append_mla_(std::move(f_transpose_append_mla)),
f_compact_copy_(std::move(f_compact_copy)),
f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)),
f_attention_prefill_(std::move(f_attention_prefill)),
f_attention_decode_(std::move(f_attention_decode)),
f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)),
f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)),
f_attention_prefill_with_tree_mask_paged_kv_(
std::move(f_attention_prefill_with_tree_mask_paged_kv)),
f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)),
f_mla_prefill_(std::move(f_mla_prefill)),
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
f_copy_single_page_(std::move(f_copy_single_page)),
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
// Note: For MLA, sliding window and disaggregation are disabled for now.
if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMLA) != attn_kinds_.end()) {
CHECK(!support_sliding_window_) << "Sliding window not supported yet for MLA";
CHECK(!enable_kv_transfer) << "KV transfer not supported yet for MLA";
}
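// The KV data is stored as one paged NDArray per layer in `pages_`.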
pages_.reserve(num_layers);
if (enable_kv_transfer) {
// For now, KV transfer only supports MHA.
for (AttnKind attn_kind : attn_kinds_) {
CHECK(attn_kind == AttnKind::kMHA);
}
const auto f_nvshmem_init =
tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.init_nvshmem");
CHECK(f_nvshmem_init.has_value())
<< "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when compiling TVM.";
const auto f_nvshmem_empty = tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.empty");
ICHECK(f_nvshmem_empty.has_value());
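// Allocate one unified NVSHMEM buffer covering all layers so that KV
// pages can be transferred across devices; per-layer views are carved
// out of this buffer below.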
nvshmem_pages_ =
(*f_nvshmem_empty)(
ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}),
dtype, device)
.cast<NDArray>();
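// Each layer's page array is a view into the shared NVSHMEM buffer.
// The byte offset of layer i is i times the number of elements per
// layer times the element size.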
for (int i = 0; i < num_layers; ++i) {
pages_.push_back(nvshmem_pages_.CreateView(
{num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype,
i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * qk_head_dim_ *
nvshmem_pages_.DataType().bytes()));
}
const auto f_transfer_kv_ptr = tvm::ffi::Function::GetGlobal("nvshmem.KVTransfer");
const auto f_transfer_kv_page_to_page_ptr =
tvm::ffi::Function::GetGlobal("nvshmem.KVTransferPageToPage");
ICHECK(f_transfer_kv_ptr.has_value());
ICHECK(f_transfer_kv_page_to_page_ptr.has_value());
f_transfer_kv_ = *f_transfer_kv_ptr;
f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
} else {
for (int i = 0; i < num_layers; ++i) {
ShapeTuple kv_cache_shape =
GetKVCacheShape(attn_kinds_[layer_id_begin_offset_ + i], num_total_pages,
reserved_num_seqs, num_kv_heads, page_size, qk_head_dim, v_head_dim);
pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device));
}
}
// Allocate the host memory.
Device preferred_host_device = GetPreferredHostDevice(device);
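// Host-side staging vectors, one set per block depth (up to
// kPagedKVCacheMaxBlockDepth), used to assemble the per-depth attention
// auxiliary data (indptr arrays, page indices, lengths, RoPE offsets,
// tree-attention masks) before it is copied to the device.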
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
qo_indptr_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
page_indptr_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
page_indices_on_depths_host_.push_back(
HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device));
last_page_len_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
sliding_window_offset_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
sink_size_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
k_rope_pos_offset_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs,
dtype_aux_, preferred_host_device));
tree_attn_mn_indptr_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
}
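// Host-side staging vectors shared across depths: ragged RoPE offsets,
// query RoPE positions, the append position map, and the KV-transfer and
// commit-copy maps, sized by the number of reserved sequences or the
// prefill chunk size.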
k_ragged_rope_pos_offset_host_ =
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
q_rope_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
append_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_remote_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_recver_id_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_page_to_page_local_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_page_to_page_remote_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_page_to_page_recver_id_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
cur_append_lengths_indptr_host_ =
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
commit_copy_length_indptr_host_ =
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
commit_copy_src_pos_in_page_table_host_ =
HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
dtype_aux_, preferred_host_device);
commit_copy_dst_pos_in_page_table_host_ =
HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
dtype_aux_, preferred_host_device);
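// Per-depth device workspaces for attention kernels (only when the
// backend requires an explicit begin-forward step), plus empty
// placeholders for the device-side views of the auxiliary data, which
// are populated later before the attention kernels run.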
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
if (NeedKernelBeginForward()) {
temp_int_attn_workspace_.push_back(
NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device));
temp_int_pinned_attn_workspace_.push_back(NDArray::Empty(
{kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device)));
}
qo_indptr_on_depths_view_.push_back(NDArray());
page_indptr_on_depths_view_.push_back(NDArray());
page_indices_on_depths_view_.push_back(NDArray());
length_info_on_depths_view_.push_back(NDArray());
k_rope_pos_offset_view_.push_back(NDArray());
tree_attn_mask_view_.push_back(NDArray());
tree_attn_mn_indptr_view_.push_back(NDArray());
is_chain_on_depths_.push_back(true);
}
// Additional workspace for the "prefill with ragged kv" kernel.
if (NeedKernelBeginForward()) {
temp_int_attn_workspace_.push_back(
NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device));
temp_int_pinned_attn_workspace_.push_back(NDArray::Empty(
{kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device)));
temp_float_attn_workspace_ =
NDArray::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device);
}
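// Temporary Q/K/V staging buffers are only needed when at least one
// layer uses standard multi-head attention.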
if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) {
temp_attn_q_device_ =
NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device);
temp_attn_k_device_ =
NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device);
temp_attn_v_device_ =
NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device);
}
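// Attention output and log-sum-exp (LSE) buffers; the LSE values are
// used when merging partial attention results in place.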
temp_attn_output_device_ =
NDArray::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device);
temp_attn_lse_device_ =
NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device);
merged_attn_lse_device_ =
NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device);
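// Populate the free page list in descending order of page id, so the
// lowest-numbered pages sit at the back of the list.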
for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
free_page_ids_.push_back(page_id);
}
// If the device is CUDA/ROCm, create a standalone copy stream and a KV
// transfer stream, so that auxiliary data copies and KV transfers can
// overlap with compute on the default stream.
if (device.device_type == DLDeviceType::kDLCUDA ||
device.device_type == DLDeviceType::kDLROCM) {
// The compute stream is the default stream.
compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
kv_transfer_stream_ = DeviceAPI::Get(device)->CreateStream(device);
}
// Create the auxiliary data manager for attention.
// The merged (cached) aux data path is only used for CUDA and CPU,
// since direct pointer operations may have issues on other platforms.
if (device_.device_type == DLDeviceType::kDLCUDA ||
device_.device_type == DLDeviceType::kDLCPU) {
aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
preferred_host_device, copy_stream_);
} else {
aux_data_manager_ = std::make_unique<PlainPagedKVCacheAuxDataManager>(
reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
preferred_host_device, copy_stream_);
}
// Right now only the "normal" RoPE mode supports the RoPE extension factors.
if (rope_ext_factors_.defined()) {
CHECK(rope_mode_ == RoPEMode::kNormal)
<< "The RoPE mode must be normal to support RoPE extension factors.";
}
}