explicit PagedAttentionKVCacheObj()

in src/runtime/relax_vm/paged_kv_cache.cc [273:493]


  explicit PagedAttentionKVCacheObj(
      int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,
      int64_t layer_id_end_offset, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim,
      int64_t v_head_dim, std::vector<AttnKind> attn_kinds, int64_t reserved_num_seqs,
      int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
      RoPEMode rope_mode, double rotary_scale, double rotary_theta,
      Optional<NDArray> rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device,
      Optional<PackedFunc> f_transpose_append_mha, Optional<PackedFunc> f_transpose_append_mla,
      PackedFunc f_compact_copy, std::unique_ptr<RaggedPrefillFunc> f_attention_prefill_ragged,
      std::unique_ptr<PagedPrefillFunc> f_attention_prefill,
      std::unique_ptr<PagedDecodeFunc> f_attention_decode,
      std::unique_ptr<PagedPrefillFunc> f_attention_prefill_sliding_window,
      std::unique_ptr<PagedDecodeFunc> f_attention_decode_sliding_window,
      std::unique_ptr<PagedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask_paged_kv,
      std::unique_ptr<RaggedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask,
      std::unique_ptr<PagedPrefillFunc> f_mla_prefill, Array<PackedFunc> f_merge_inplace,
      PackedFunc f_split_rotary, PackedFunc f_copy_single_page, PackedFunc f_debug_get_kv)
      : page_size_(page_size),
        num_layers_(num_layers),
        layer_id_begin_offset_(layer_id_begin_offset),
        layer_id_end_offset_(layer_id_end_offset),
        num_qo_heads_(num_qo_heads),
        num_kv_heads_(num_kv_heads),
        qk_head_dim_(qk_head_dim),
        v_head_dim_(v_head_dim),
        num_total_pages_(num_total_pages),
        prefill_chunk_size_(prefill_chunk_size),
        support_sliding_window_(support_sliding_window),
        attn_kinds_(std::move(attn_kinds)),
        rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline
                                                                          : rope_mode),
        rotary_scale_(rotary_scale),
        rotary_theta_(rotary_theta),
        rope_ext_factors_(std::move(rope_ext_factors)),
        kv_dtype_(DataType(dtype)),
        f_transpose_append_mha_(std::move(f_transpose_append_mha)),
        f_transpose_append_mla_(std::move(f_transpose_append_mla)),
        f_compact_copy_(std::move(f_compact_copy)),
        f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)),
        f_attention_prefill_(std::move(f_attention_prefill)),
        f_attention_decode_(std::move(f_attention_decode)),
        f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)),
        f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)),
        f_attention_prefill_with_tree_mask_paged_kv_(
            std::move(f_attention_prefill_with_tree_mask_paged_kv)),
        f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)),
        f_mla_prefill_(std::move(f_mla_prefill)),
        f_merge_inplace_(std::move(f_merge_inplace)),
        f_split_rotary_(std::move(f_split_rotary)),
        f_copy_single_page_(std::move(f_copy_single_page)),
        f_debug_get_kv_(std::move(f_debug_get_kv)),
        device_(device) {
    // Note: For MLA, sliding window and disaggregation are disabled for now.
    if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMLA) != attn_kinds_.end()) {
      CHECK(!support_sliding_window_) << "Sliding window not supported yet for MLA";
      CHECK(!enable_kv_transfer) << "KV transfer not supported yet for MLA";
    }

    pages_.reserve(num_layers);
    if (enable_kv_transfer) {
      // For now, KV transfer only supports MHA.
      for (AttnKind attn_kind : attn_kinds_) {
        CHECK(attn_kind == AttnKind::kMHA);
      }
      const auto f_nvshmem_init =
          tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.init_nvshmem");
      CHECK(f_nvshmem_init.has_value())
          << "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when compiling TVM.";
      const auto f_nvshmem_empty = tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.empty");
      ICHECK(f_nvshmem_empty.has_value());
      nvshmem_pages_ =
          (*f_nvshmem_empty)(
              ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}),
              dtype, device)
              .cast<NDArray>();
      for (int i = 0; i < num_layers; ++i) {
        pages_.push_back(nvshmem_pages_.CreateView(
            {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype,
            i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * qk_head_dim_ *
                nvshmem_pages_.DataType().bytes()));
      }

      const auto f_transfer_kv_ptr = tvm::ffi::Function::GetGlobal("nvshmem.KVTransfer");
      const auto f_transfer_kv_page_to_page_ptr =
          tvm::ffi::Function::GetGlobal("nvshmem.KVTransferPageToPage");
      ICHECK(f_transfer_kv_ptr.has_value());
      ICHECK(f_transfer_kv_page_to_page_ptr.has_value());
      f_transfer_kv_ = *f_transfer_kv_ptr;
      f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
    } else {
      for (int i = 0; i < num_layers; ++i) {
        ShapeTuple kv_cache_shape =
            GetKVCacheShape(attn_kinds_[layer_id_begin_offset_ + i], num_total_pages,
                            reserved_num_seqs, num_kv_heads, page_size, qk_head_dim, v_head_dim);
        pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device));
      }
    }

    // Allocate the host memory.
    Device preferred_host_device = GetPreferredHostDevice(device);
    for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
      qo_indptr_on_depths_host_.push_back(
          HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
      page_indptr_on_depths_host_.push_back(
          HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
      page_indices_on_depths_host_.push_back(
          HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device));
      last_page_len_on_depths_host_.push_back(
          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
      sliding_window_offset_on_depths_host_.push_back(
          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
      sink_size_on_depths_host_.push_back(
          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
      k_rope_pos_offset_on_depths_host_.push_back(
          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
      tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs,
                                                      dtype_aux_, preferred_host_device));
      tree_attn_mn_indptr_host_.push_back(
          HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
    }
    k_ragged_rope_pos_offset_host_ =
        HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
    q_rope_position_map_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    append_position_map_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    kv_transfer_remote_position_map_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    kv_transfer_recver_id_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    kv_transfer_page_to_page_local_position_map_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    kv_transfer_page_to_page_remote_position_map_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    kv_transfer_page_to_page_recver_id_host_ =
        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
    cur_append_lengths_indptr_host_ =
        HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
    commit_copy_length_indptr_host_ =
        HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
    commit_copy_src_pos_in_page_table_host_ =
        HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
                         dtype_aux_, preferred_host_device);
    commit_copy_dst_pos_in_page_table_host_ =
        HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
                         dtype_aux_, preferred_host_device);

    for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
      if (NeedKernelBeginForward()) {
        temp_int_attn_workspace_.push_back(
            NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device));
        temp_int_pinned_attn_workspace_.push_back(NDArray::Empty(
            {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device)));
      }
      qo_indptr_on_depths_view_.push_back(NDArray());
      page_indptr_on_depths_view_.push_back(NDArray());
      page_indices_on_depths_view_.push_back(NDArray());
      length_info_on_depths_view_.push_back(NDArray());
      k_rope_pos_offset_view_.push_back(NDArray());
      tree_attn_mask_view_.push_back(NDArray());
      tree_attn_mn_indptr_view_.push_back(NDArray());
      is_chain_on_depths_.push_back(true);
    }
    // Additional workspace for the "prefill with ragged kv" kernel.
    if (NeedKernelBeginForward()) {
      temp_int_attn_workspace_.push_back(
          NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device));
      temp_int_pinned_attn_workspace_.push_back(NDArray::Empty(
          {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device)));
      temp_float_attn_workspace_ =
          NDArray::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device);
    }

    if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) {
      temp_attn_q_device_ =
          NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device);
      temp_attn_k_device_ =
          NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device);
      temp_attn_v_device_ =
          NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device);
    }
    temp_attn_output_device_ =
        NDArray::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device);
    temp_attn_lse_device_ =
        NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device);
    merged_attn_lse_device_ =
        NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device);
    for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
      free_page_ids_.push_back(page_id);
    }

    // If the device is CUDA/ROCm, we create a standalone copy stream, in
    // purpose to hide the latency of auxiliary stream copy.
    if (device.device_type == DLDeviceType::kDLCUDA ||
        device.device_type == DLDeviceType::kDLROCM) {
      // The compute stream is the default stream.
      compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
      copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
      kv_transfer_stream_ = DeviceAPI::Get(device)->CreateStream(device);
    }

    // Create the auxiliary data manager for attention.
    // We only use the merged aux data for CUDA, since direct pointer
    // operations may have issues on other platforms.
    if (device_.device_type == DLDeviceType::kDLCUDA ||
        device_.device_type == DLDeviceType::kDLCPU) {
      aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
          preferred_host_device, copy_stream_);
    } else {
      aux_data_manager_ = std::make_unique<PlainPagedKVCacheAuxDataManager>(
          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
          preferred_host_device, copy_stream_);
    }

    // Right now only the "normal" RoPE mode supports the RoPE extention factors.
    if (rope_ext_factors_.defined()) {
      CHECK(rope_mode_ == RoPEMode::kNormal)
          << "The RoPE mode must be normal to support RoPE extension factors.";
    }
  }