void BeginForward()

in src/runtime/relax_vm/paged_kv_cache.cc [823:1095]


  /*!
   * \brief Set up the per-batch state and the `*_host_` auxiliary vectors for one forward pass.
   *
   * In order, this function:
   *  1. validates the inputs and records the batch (`cur_batch_size_`, `cur_seq_ids_`,
   *     `cur_append_lengths_`);
   *  2. looks up every sequence in `seq_map_`, records its last-block length and k RoPE offset
   *     before the append, and adds the append length to its `seq_length`;
   *  3. computes the block ids on each depth and their chunking, and decides
   *     `append_before_attn_` and `use_decode_kernel_`;
   *  4. processes the token tree when one is given, otherwise marks every sequence as a chain;
   *  5. fills the per-depth page-table vectors, calling ReserveAppendLengthInSeq either before
   *     or after that step depending on `append_before_attn_`;
   *  6. fills the q RoPE position map, the append position map and the KV-transfer maps.
   *
   * \param seq_ids Ids of the sequences in the batch. Every id must be present in `seq_map_`.
   * \param append_lengths Number of tokens appended to each sequence. Must have the same size
   *        as `seq_ids`.
   * \param opt_token_tree_parent_ptr Optional token tree parent pointers. When defined, the
   *        batch is processed through ConstructTokenTreeMask; this is rejected for MLA, for
   *        sliding window support and for inline RoPE mode. When undefined, every sequence in
   *        the batch must have `accepted_indices_committed` set.
   */
  void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
                    const Optional<IntTuple>& opt_token_tree_parent_ptr) final {
    // Note: MLA does not support tree attention for now.
    if (attn_kinds_[0] == AttnKind::kMLA) {
      CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA";
    }

    CHECK_EQ(seq_ids.size(), append_lengths.size())
        << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
        << append_lengths.size() << ") mismatch.";
    cur_batch_size_ = seq_ids.size();
    cur_seq_ids_ = seq_ids;
    cur_append_lengths_ = append_lengths;

    // - Collect sequence/block/page information for attention.
    std::vector<Sequence*> sequences;
    std::vector<int32_t> last_block_length_before_append;
    // The batch counts as a decode request only if every append length is exactly 1.
    is_decode_request_ = true;
    sequences.reserve(cur_batch_size_);
    last_block_length_before_append.reserve(cur_batch_size_);
    k_ragged_rope_pos_offset_host_.clear();
    for (int i = 0; i < cur_batch_size_; ++i) {
      auto it = seq_map_.find(seq_ids[i]);
      CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
                                  << "\" cannot be found in KV cache.";
      sequences.push_back(&it->second);
      // Snapshot the last block length now; the append position map below needs the
      // length from before this append.
      last_block_length_before_append.push_back(
          global_block_pool_[it->second.last_block_idx].seq_length);
      // The k RoPE offset is the sequence length before this append, minus the size of a
      // previous token tree whose accepted nodes have not been committed yet.
      int k_rope_offset = it->second.seq_length;
      if (!it->second.accepted_indices_committed) {
        int tree_size = static_cast<int>(it->second.token_tree_parent_ptr.size());
        k_rope_offset -= tree_size;
      }
      k_ragged_rope_pos_offset_host_.push_back(k_rope_offset);
      it->second.seq_length += append_lengths[i];
      if (append_lengths[i] != 1) {
        is_decode_request_ = false;
      }
    }

    auto [block_ids_on_depths, trailing_blocks] =
        GetBlockIdsOnDepth(sequences, global_block_pool_, cur_batch_size_);
    // Only the first kPagedKVCacheMaxBlockDepth depths are processed; blocks beyond that
    // are handled through `trailing_blocks` at the maximum depth.
    num_depths_ =
        std::min(static_cast<int>(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth);
    ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);

    std::vector<std::vector<std::pair<int32_t, int32_t>>> chunked_block_ids_arr;
    chunked_block_ids_arr.reserve(num_depths_);
    use_decode_kernel_.clear();
    for (int d = 0; d < num_depths_; ++d) {
      // We force the blocks at maximum depth not to coalesce, so that it can be concatenated with
      // trailing exceeding blocks.
      auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(
          block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1,
          cur_append_lengths_, global_block_pool_, is_decode_request_);
      chunked_block_ids_arr.push_back(chunked_block_ids);
      use_decode_kernel_.push_back(use_decode_kernel);
    }

    if (num_depths_ == kPagedKVCacheMaxBlockDepth) {
      // Since we force the blocks at maximum depth not to coalesce, the output blocks at maximum
      // depth must have the same size as current batch.
      CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
    }

    // `use_decode_kernel_` holds one entry per depth, so it is empty when there is no depth
    // at all (e.g. an empty batch). Calling `back()` on an empty vector is undefined behavior,
    // hence the explicit emptiness guard.
    append_before_attn_ =
        !support_sliding_window_ && !use_decode_kernel_.empty() && use_decode_kernel_.back();
    if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
      // When GQA group size is at least 4 and FlashInfer is enabled,
      // we always use prefill kernel for better performance.
      // Note: For MLA, we always use prefill kernel, so values in `use_decode_kernel` will
      // be ignored for MLA.
      std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
    }

    // If any sequence still carries an uncommitted token tree, always append before attention.
    bool has_previous_tree =
        std::any_of(sequences.begin(), sequences.end(),
                    [](const Sequence* sequence) { return !sequence->accepted_indices_committed; });
    if (has_previous_tree) {
      append_before_attn_ = true;
    }

    // - Check token tree validity and process the token tree.
    if (opt_token_tree_parent_ptr.defined()) {
      CHECK(!support_sliding_window_) << "Tree attention does not support sliding window.";
      CHECK(rope_mode_ != RoPEMode::kInline) << "Tree attention does not support inline RoPE mode.";
      ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value(), block_ids_on_depths,
                             trailing_blocks);
    } else {
      // The input batch does not form trees. So each sequence in the batch
      // is required to have all past accepted tokens committed.
      for (int i = 0; i < cur_batch_size_; ++i) {
        Sequence* sequence = sequences[i];
        CHECK(sequence->accepted_indices_committed)
            << "The input batch does not form a tree, in which case the sequences in the input "
               "batch are expected to have their accepted tokens token tree nodes committed. "
               "Please invoke CommitAcceptedTokenTreeNodes for sequence "
            << seq_ids[i];
        // Reset the per-sequence tree state: the sequence is a plain chain.
        sequence->is_chain = true;
        sequence->token_tree_parent_ptr.clear();
        sequence->token_tree_node_depths.clear();
      }
      std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
    }

    if (append_before_attn_) {
      // Right now we use different kernels when depth is 1 or not 1.
      // For the case where maximum depth is 1, we create the auxiliary
      // data structure with regard to the page table after appending.
      for (int i = 0; i < cur_batch_size_; ++i) {
        ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
      }
    }

    // Build the per-depth page-table vectors from the chunked block ids.
    for (int d = 0; d < num_depths_; ++d) {
      HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d];
      HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d];
      HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d];
      HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d];
      HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d];
      HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d];
      HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
      qo_indptr_h.clear();
      page_indptr_h.clear();
      page_indices_h.clear();
      last_page_len_h.clear();
      sliding_window_offset_h.clear();
      sink_size_h.clear();
      k_rope_pos_offset_h.clear();
      // Both indptr vectors are prefix sums and therefore start at 0.
      qo_indptr_h.push_back(0);
      page_indptr_h.push_back(0);
      for (int i = 0; i < static_cast<int>(chunked_block_ids_arr[d].size()); ++i) {
        const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i];
        qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
        if (block_id == -1) {
          // A block id of -1 contributes no pages and zeroed metadata (presumably it marks
          // an entry that has no block on this depth -- confirm against GetChunkedBlockIds).
          page_indptr_h.push_back(page_indptr_h.back());
          last_page_len_h.push_back(0);
          sliding_window_offset_h.push_back(0);
          sink_size_h.push_back(0);
          k_rope_pos_offset_h.push_back(0);
        } else {
          if (d < kPagedKVCacheMaxBlockDepth - 1) {
            // Blocks not at maximum depth
            const Block& block = global_block_pool_[block_id];
            page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
            for (int32_t page_id : block.page_ids) {
              page_indices_h.push_back(page_id);
            }
            // For a non-empty block the expression below yields a value in [1, page_size_].
            last_page_len_h.push_back(
                block.seq_length == 0
                    ? 0
                    : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
                              page_size_ +
                          1);
            sliding_window_offset_h.push_back(block.sliding_window_offset);
            sink_size_h.push_back(block.sink_length);
            k_rope_pos_offset_h.push_back(block.start_pos);
          } else {
            // Blocks at maximum depth
            const Block& block = global_block_pool_[block_id];
            int32_t num_pages = static_cast<int32_t>(block.page_ids.size());
            int32_t total_seq_length = static_cast<int32_t>(block.seq_length);
            int32_t last_block_id = block_id;
            for (int32_t page_id : block.page_ids) {
              page_indices_h.push_back(page_id);
            }
            for (int32_t id : trailing_blocks[i]) {
              // Collect trailing blocks if available
              const Block& trailing_block = global_block_pool_[id];
              for (int32_t page_id : trailing_block.page_ids) {
                page_indices_h.push_back(page_id);
              }
              num_pages += trailing_block.page_ids.size();
              total_seq_length += trailing_block.seq_length;
              last_block_id = id;
            }
            page_indptr_h.push_back(page_indptr_h.back() + num_pages);
            // Sink/sliding-window metadata comes from the last block of the concatenation,
            // while the RoPE offset comes from the first one (the block at maximum depth).
            const Block& last_block = global_block_pool_[last_block_id];
            last_page_len_h.push_back(total_seq_length == 0
                                          ? 0
                                          : (total_seq_length - last_block.sink_length +
                                             last_block.sliding_window_offset - 1) %
                                                    page_size_ +
                                                1);
            sliding_window_offset_h.push_back(last_block.sliding_window_offset);
            sink_size_h.push_back(last_block.sink_length);
            k_rope_pos_offset_h.push_back(block.start_pos);
          }
        }
      }
    }

    if (!append_before_attn_) {
      // Right now we use different kernels when depth is 1 or not 1.
      // For the case where maximum depth is not 1, we create the auxiliary
      // data structure with regard to the page table before appending.
      for (int i = 0; i < cur_batch_size_; ++i) {
        ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
      }
    }

    // Map each token position in the input batch to its position
    // in the global KV cache. The mapping is used when appending k/v values.
    q_rope_position_map_host_.clear();
    append_position_map_host_.clear();
    kv_transfer_remote_position_map_host_.clear();
    kv_transfer_recver_id_host_.clear();
    kv_transfer_page_to_page_local_position_map_host_.clear();
    kv_transfer_page_to_page_remote_position_map_host_.clear();
    kv_transfer_page_to_page_recver_id_host_.clear();
    transfer_kv_ = false;
    page_to_page_transfer_kv_ = false;
    for (int i = 0; i < cur_batch_size_; ++i) {
      int64_t append_length = append_lengths[i];
      const Block& block = global_block_pool_[sequences[i]->last_block_idx];
      for (int64_t pos = 0; pos < append_length; ++pos) {
        if (sequences[i]->token_tree_node_depths.empty()) {
          // No token tree: the appended tokens take consecutive RoPE positions.
          q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos);
        } else {
          // With a token tree, the appended tokens are the last `append_length` tree nodes,
          // and each node's RoPE position is offset by its depth in the tree.
          int64_t offset_in_tree =
              static_cast<int64_t>(sequences[i]->token_tree_parent_ptr.size()) - append_length;
          ICHECK_GE(offset_in_tree, 0);
          q_rope_position_map_host_.push_back(
              k_ragged_rope_pos_offset_host_[i] +
              sequences[i]->token_tree_node_depths[offset_in_tree + pos]);
        }

        int32_t pos_in_block = block.seq_length - append_length + pos;
        if (last_block_length_before_append[i] + pos < block.sink_length) {
          // The location to write is part of the attention sink.
          int32_t offset_in_block = last_block_length_before_append[i] + pos;
          append_position_map_host_.push_back(block.page_ids[offset_in_block / page_size_] *
                                                  page_size_ +
                                              offset_in_block % page_size_);
        } else if (pos_in_block < block.sink_length) {
          // The location to write is pinned by attn sink before the append.
          // Therefore we cannot write into the location.
          append_position_map_host_.push_back(-1);
        } else {
          // The location to write is in the sliding window.
          int32_t offset_in_block = pos_in_block - block.sink_length + block.sliding_window_offset;
          append_position_map_host_.push_back(block.page_ids[offset_in_block / page_size_] *
                                                  page_size_ +
                                              offset_in_block % page_size_);
        }
        // Positions before `kv_transfer_metadata.start` are not transferred and get -1.
        int64_t pos_in_seq = sequences[i]->seq_length - append_length + pos;
        int64_t seq_send_start = sequences[i]->kv_transfer_metadata.start;
        if (pos_in_seq < seq_send_start) {
          kv_transfer_remote_position_map_host_.push_back(-1);
          kv_transfer_recver_id_host_.push_back(-1);
        } else {
          transfer_kv_ = true;
          // NOTE(review): no bounds check here; presumably `remote_position_map` covers every
          // position from `start` onwards -- confirm against the code that fills it.
          kv_transfer_remote_position_map_host_.push_back(
              sequences[i]->kv_transfer_metadata.remote_position_map[pos_in_seq - seq_send_start]);
          kv_transfer_recver_id_host_.push_back(
              sequences[i]->kv_transfer_metadata.recver_pe_offset);
        }
      }
      if (!sequences[i]->kv_transfer_metadata.local_position_map.empty()) {
        page_to_page_transfer_kv_ = true;
        // NOTE(review): `remote_position_map` is indexed over the range of
        // `local_position_map`; presumably it is at least as long -- confirm.
        for (int pos = 0;
             pos < static_cast<int>(sequences[i]->kv_transfer_metadata.local_position_map.size());
             ++pos) {
          kv_transfer_page_to_page_local_position_map_host_.push_back(
              sequences[i]->kv_transfer_metadata.local_position_map[pos]);
          kv_transfer_page_to_page_remote_position_map_host_.push_back(
              sequences[i]->kv_transfer_metadata.remote_position_map[pos]);
          kv_transfer_page_to_page_recver_id_host_.push_back(
              sequences[i]->kv_transfer_metadata.recver_pe_offset);
        }
        // The local map is consumed once: clear it so it is not sent again on the next call.
        sequences[i]->kv_transfer_metadata.local_position_map.clear();
      }
    }
  }