maga_transformer/cpp/devices/cuda_impl/CudaFlashInfer.cc
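// FlashInfer / FlashMLA attention backend for the CUDA device: builds the paged-KV
// index buffers (indptr / indices / last-page lengths), plans the FlashInfer or
// FlashMLA kernels, and runs prefill / decode attention over the paged KV cache.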

#include <iostream>
#include <numeric>
#include "maga_transformer/cpp/devices/OpData.h"
#include "maga_transformer/cpp/devices/cuda_impl/CudaDevice.h"
#include "maga_transformer/cpp/devices/cuda_impl/CudaFlashInfer.h"
#include "maga_transformer/cpp/devices/CommonDefines.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
#include "maga_transformer/cpp/cuda/Dispatch.h"
#include "maga_transformer/cpp/utils/compiler_config.h"
#include "maga_transformer/cpp/kernels/kv_cache/kv_cache_utils.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "3rdparty/flashinfer/flashinfer.h"
#include "flashmla/flashmla.h"

using namespace std;
using namespace rtp_llm;

namespace rtp_llm {

using Slice = torch::indexing::Slice;
constexpr auto TNone = torch::indexing::None;

static std::deque<FlashInferAttnParams*> PARAMS_CACHE;
static int MIN_CACHE_BATCH_SIZE      = 256;
static int MIN_CACHE_INPUT_TOKEN_NUM = 512;
static int MIN_CACHE_PAGE_NUM        = 48 * 1024;

void FlashInferAttnParamsDel(void* p) {
    PARAMS_CACHE.push_back((FlashInferAttnParams*)p);
}

tuple<BufferPtr, vector<torch::Tensor>> FlashInferAttnParams::allocateManyBuffer(
    CudaDevice*                              device,
    const std::vector<std::vector<int64_t>>& shapes,
    AllocationType                           atype)
{
    vector<torch::Tensor> tensors;
    vector<size_t>        sizes;
    size_t                total_size = 0;
    for (const auto& shape : shapes) {
        size_t size = 1;
        for (const auto dim : shape) {
            size *= dim;
        }
        size = (size + 31) / 32 * 32;
        sizes.push_back(size);
        total_size += size;
    }
    auto buf         = device->allocateBuffer({DataType::TYPE_INT32, {total_size}, atype}, {"flashinfer_buf"});
    auto buf_ptr     = buf->data<int>();
    auto cuda_option = torch::dtype(torch::kInt).device(torch::DeviceType::CUDA).requires_grad(false);

    size_t offset = 0;
    for (size_t i = 0; i < sizes.size(); i++) {
        tensors.emplace_back(torch::from_blob(buf_ptr + offset, shapes[i], cuda_option));
        offset += sizes[i];
    }
    return {buf, tensors};
}
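// Creates (or recycles from PARAMS_CACHE) a FlashInferAttnParams object sized for at least
// `batch_size` requests, `input_token_num` tokens and `page_num` KV-cache pages, allocating
// the FlashInfer workspaces plus paired host/device index buffers.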
FlashInferAttnParams* FlashInferAttnParams::create(CudaDevice* device,
                                                   int         batch_size,
                                                   int         input_token_num,
                                                   int         page_num)
{
    if (!PARAMS_CACHE.empty()) {
        auto params = PARAMS_CACHE.back();
        PARAMS_CACHE.pop_back();
        if (batch_size < params->batch_size && input_token_num < params->input_token_num) {
            return params;
        }
        delete params;
    }

    auto params = make_unique<FlashInferAttnParams>();
    params->batch_size      = batch_size;
    params->input_token_num = input_token_num;
    params->page_num        = page_num;

    // batch_prefill_tmp_v may use 256M buffer
    params->float_workspace = device->allocateBuffer(
        {DataType::TYPE_INT8, {(256 + 16) * 1024 * 1024}, AllocationType::DEVICE}, {"float_workspace"});
    params->int_workspace = device->allocateBuffer(
        {DataType::TYPE_INT8, {8 * 1024 * 1024}, AllocationType::DEVICE}, {"int_workspace"});
    params->int_host_workspace = device->allocateBuffer(
        {DataType::TYPE_INT8, {8 * 1024 * 1024}, AllocationType::HOST}, {"int_host_workspace"});

    params->float_workspace_d = Buffer2torchTensor(params->float_workspace, false);
    params->int_workspace_d   = Buffer2torchTensor(params->int_workspace, false);
    params->int_workspace_h   = Buffer2torchTensor(params->int_host_workspace, false);

#define ALLOC_BUFFER(suffix, type)                                                              \
    do {                                                                                        \
        auto alloc_ret = allocateManyBuffer(device,                                             \
                                            {                                                   \
                                                {batch_size + 1},  /* page_indptr */            \
                                                {batch_size + 1},  /* qo_indptr */              \
                                                {input_token_num}, /* batch_indice */           \
                                                {input_token_num}, /* positions */              \
                                                {batch_size},      /* kv_len */                 \
                                                {batch_size},      /* paged_kv_last_page_len */ \
                                                {page_num}         /* page_indice */            \
                                            },                                                  \
                                            type);                                              \
                                                                                                \
        params->buf_##suffix = get<0>(alloc_ret);                                               \
        auto& tensors        = get<1>(alloc_ret);                                               \
        params->page_indptr_##suffix            = tensors[0];                                   \
        params->qo_indptr_##suffix              = tensors[1];                                   \
        params->batch_indice_##suffix           = tensors[2];                                   \
        params->positions_##suffix              = tensors[3];                                   \
        params->kvlen_##suffix                  = tensors[4];                                   \
        params->paged_kv_last_page_len_##suffix = tensors[5];                                   \
        params->page_indice_##suffix            = tensors[6];                                   \
    } while (0)

    ALLOC_BUFFER(h, AllocationType::HOST);
    ALLOC_BUFFER(d, AllocationType::DEVICE);

    return params.release();
}

// Fills the host-side FlashInfer index arrays (qo_indptr, page_indptr, batch_indice, positions,
// kv lengths, last-page lengths and page indices) for the current batch, handling both context
// requests (prefix_lengths given) and decode requests (sequence_lengths given).
void FlashInferAttnParams::fillFlashInfer(const BufferPtr& prefix_lengths_host,
                                          const BufferPtr& sequence_lengths_host,
                                          const BufferPtr& input_lengths_host,
                                          const BufferPtr& kv_cache_block_id_host,
                                          const int        batch_size,
                                          const int        tokens_per_block)
{
    const int max_batch_blocks = kv_cache_block_id_host ? kv_cache_block_id_host->shape()[1] : -1;
    RTP_LLM_CHECK_WITH_INFO(batch_size <= this->batch_size,
                            "batch_size exceed reserved %d > %d", batch_size, this->batch_size);

    auto qo_indptr              = qo_indptr_h.data_ptr<int>();
    auto page_indptr            = page_indptr_h.data_ptr<int>();
    auto batch_indice           = batch_indice_h.data_ptr<int>();
    auto positions              = positions_h.data_ptr<int>();
    auto paged_kv_last_page_len = paged_kv_last_page_len_h.data_ptr<int>();
    auto kvlen                  = kvlen_h.data_ptr<int>();
    auto page_indice            = page_indice_h.data_ptr<int>();

    auto input_lengths     = input_lengths_host->data<int>();
    auto sequence_lengths  = sequence_lengths_host ? sequence_lengths_host->data<int>() : nullptr;
    auto prefix_lengths    = prefix_lengths_host ? prefix_lengths_host->data<int>() : nullptr;
    auto kv_cache_block_id = kv_cache_block_id_host ? kv_cache_block_id_host->data<int>() : nullptr;

    int qo_offset      = 0;
    int offset         = 0;
    int total_page_idx = 0;
    qo_indptr[0]   = 0;
    page_indptr[0] = 0;
    for (int i = 0; i < batch_size; i++) {
        int seq_len = 0;
        if (prefix_lengths) {
            // context request: one query token per input position, offset by the reused prefix
            int input_length  = input_lengths[i];
            int prefix_length = prefix_lengths[i];
            RTP_LLM_CHECK_WITH_INFO(offset + input_length <= this->input_token_num,
                                    "token_num exceed reserved %d > %d",
                                    offset + input_length, this->input_token_num);
            for (int j = 0; j < input_length; j++) {
                batch_indice[offset] = i;
                positions[offset]    = j + prefix_length;
                offset += 1;
            }
            qo_offset += input_length;
            seq_len = input_length + prefix_length;
        } else {
            // decode request: a single new token appended at position sequence_length
            batch_indice[i] = i;
            positions[i]    = sequence_lengths[i];
            qo_offset += 1;
            seq_len = sequence_lengths[i] + 1;
        }
        paged_kv_last_page_len[i] = (seq_len - 1) % tokens_per_block + 1;
        kvlen[i]                  = seq_len;

        int page_num = (seq_len + tokens_per_block - 1) / tokens_per_block;
        RTP_LLM_CHECK_WITH_INFO(total_page_idx + page_num <= this->page_num,
                                "page_num exceed reserved %d > %d",
                                total_page_idx + page_num, this->page_num);
        if (kv_cache_block_id) {
            for (int j = 0; j < page_num; j++) {
                auto page_idx = kv_cache_block_id[i * max_batch_blocks + j];
                page_indice[total_page_idx++] = page_idx;
            }
        }
        page_indptr[i + 1] = total_page_idx;
        qo_indptr[i + 1]   = qo_offset;
    }
}

// Copies the freshly filled host index buffers to the device in a single async memcpy and
// shrinks the tensor views to the sizes actually used by this batch.
void FlashInferAttnParams::refreshFlashInferBuf(CudaDevice* device, int batch_size, int input_token_num) {
    auto stream = device->getStream();
    cudaMemcpyAsync(buf_d->data(), buf_h->data(), buf_h->sizeBytes(), cudaMemcpyHostToDevice, stream);

    vector<int64_t> shape = {batch_size + 1};
#define REFRESH_SHAPE(t)                                                \
    do {                                                                \
        t##_d.unsafeGetTensorImpl()->set_sizes_contiguous(shape);       \
        t##_h.unsafeGetTensorImpl()->set_sizes_contiguous(shape);       \
    } while (0)

    REFRESH_SHAPE(page_indptr);
    REFRESH_SHAPE(qo_indptr);

    shape[0] = input_token_num;
    REFRESH_SHAPE(batch_indice);
    REFRESH_SHAPE(positions);

    shape[0] = batch_size;
    REFRESH_SHAPE(kvlen);
    REFRESH_SHAPE(paged_kv_last_page_len);
}
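// Returns true if every request in the batch has the same query length, writing that length
// to `q_length`; FlashMLA is only selected when the per-request query lengths match.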
bool FlashInferAttnParams::sameQLength(const BufferPtr& input_lengths_host, int batch_size, int& q_length) {
    auto input_lengths = input_lengths_host->data<int>();
    for (int i = 0; i < batch_size; i++) {
        int input_length = input_lengths[i];
        if (q_length > 0 && q_length != input_length) {
            return false;
        }
        q_length = input_length;
    }
    return true;
}

// Builds the attention plan for the current batch: an MLA plan (FlashInfer BatchMLAPagedAttention
// or FlashMLA metadata) when use_mla is set, otherwise a FlashInfer decode or prefill plan.
void FlashInferAttnParams::genPlan(int     batch_size,
                                   int     q_length,
                                   int     local_head_num,
                                   int     local_head_num_kv,
                                   int     size_per_head,
                                   int     tokens_per_block,
                                   int     kv_lora_rank,
                                   bool    use_mla,
                                   int64_t stream)
{
    // std::cout << "use_mla: " << use_mla << std::endl
    //           << "mla_type: " << int(mla_ops_type) << std::endl
    //           << "page_indptr: " << page_indptr_d
    //           << "qo_indptr: " << qo_indptr_d
    //           << "batch_indice: " << batch_indice_d
    //           << "positions: " << positions_d
    //           << "kvlen: " << kvlen_d
    //           << "paged_kv_last_page_len: " << paged_kv_last_page_len_d
    //           << "page_indice: " << page_indice_d.index({torch::indexing::Slice(0, 32)})
    //           << "kv_cache_block_id: " << kv_cache_block_id_d << std::endl;
    if (use_mla) {
        if (mla_ops_type == MlaOpsType::FLASH_INFER) {
            plan = BatchMLAPagedAttentionPlan(float_workspace_d,
                                              int_workspace_d,
                                              int_workspace_h,
                                              qo_indptr_h,
                                              page_indptr_h,
                                              kvlen_h,
                                              local_head_num,
                                              kv_lora_rank,
                                              true,
                                              stream);
        } else if (mla_ops_type == MlaOpsType::FLASH_MLA) {
            RTP_LLM_LOG_TRACE("batch_size = %d", batch_size);
            RTP_LLM_LOG_TRACE("local_head_num = %d", local_head_num);
            flash_mla_plan = get_mla_metadata(kvlen_d, local_head_num * q_length, 1);
        } else {
            RTP_LLM_FAIL("unexpected mla ops type: %d", int(mla_ops_type));
        }
    } else {
        if (decode) {
            plan = BatchDecodeWithPagedKVCachePlan(
                float_workspace_d,                           // float_workspace_buffer
                int_workspace_d,                             // int_workspace_buffer
                int_workspace_h,                             // page_locked_int_workspace_buffer
                page_indptr_h,                               // indptr
                batch_size,                                  // batch_size
                local_head_num,                              // num_qo_heads
                local_head_num_kv,                           // num_kv_heads
                tokens_per_block,                            // page_size
                false,                                       // enable_cuda_graph
                -1,                                          // window_left
                -1,                                          // logits_soft_cap
                size_per_head,                               // head_dim_qk
                size_per_head,                               // head_dim_vo
                torch::empty(0, dataTypeToTorchType(dtype)), // empty_q_data
                torch::empty(0, dataTypeToTorchType(dtype)), // empty_kv_data
                stream);
        } else {
            plan = BatchPrefillWithKVCachePlan(
                float_workspace_d,                                           // float_workspace_buffer
                int_workspace_d,                                             // int_workspace_buffer
                int_workspace_h,                                             // page_locked_int_workspace_buffer
                qo_indptr_h,                                                 // qo_indptr
                page_indptr_h,                                               // kv_indptr
                torch::empty(0, dataTypeToTorchType(DataType::TYPE_INT32)),  // kv_len_arr, not in use yet
                batch_size,                                                  // total_num_rows
                batch_size,                                                  // batch_size
                local_head_num,                                              // num_qo_heads
                local_head_num_kv,                                           // num_kv_heads
                tokens_per_block,                                            // page_size
                false,                                                       // enable_cuda_graph
                size_per_head,                                               // head_dim_qk
                size_per_head,                                               // head_dim_vo
                true,                                                        // causal
                stream);
        }
    }
}
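// Entry point called before attention: checks that FlashInfer supports the current device and
// attention configuration (returning nullptr to fall back otherwise), picks between FlashMLA and
// FlashInfer for MLA, fills the host-side paged-KV indices, copies them to the device and
// generates the kernel plan.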
FlashInferAttnParamsPtr FlashInferAttnParams::prepare(rtp_llm::DeviceBase*             device,
                                                      const rtp_llm::AttentionConfigs& attn_configs,
                                                      const BufferPtr&                 prefix_lengths_host,
                                                      const BufferPtr&                 sequence_lengths_host,
                                                      const BufferPtr&                 input_lengths_host,
                                                      const BufferPtr&                 kv_cache_block_id_host,
                                                      const BufferPtr&                 kv_cache_block_id_device,
                                                      rtp_llm::DataType                dtype)
{
    if (rtp_llm::get_sm() < 80) {
        return nullptr;
    }

    const int batch_size = input_lengths_host->shape()[0];
    if (batch_size == 0) {
        return nullptr;
    }

    auto cuda_device = dynamic_cast<CudaDevice*>(device);
    if (!cuda_device) {
        return nullptr;
    }

    MlaOpsType mla_ops_type = device->mla_ops_type;
    int        q_length     = -1;
    // FlashMLA requires every request to share the same query length (at most 32);
    // otherwise fall back to the FlashInfer MLA kernels.
    if (mla_ops_type == MlaOpsType::FLASH_MLA
        && (!sameQLength(input_lengths_host, batch_size, q_length) || q_length == -1 || q_length > 32)) {
        mla_ops_type = MlaOpsType::FLASH_INFER;
    }

    const char* disable_flash_infer_env = getenv("DISABLE_FLASH_INFER");
    const bool  disable_flash_infer     = disable_flash_infer_env && strcmp(disable_flash_infer_env, "1") == 0;
    if ((!attn_configs.use_mla || mla_ops_type == MlaOpsType::FLASH_INFER) && disable_flash_infer) {
        return nullptr;
    }

    const int local_head_num    = attn_configs.head_num;
    const int local_head_num_kv = attn_configs.kv_head_num;
    const int size_per_head     = attn_configs.size_per_head;
    const int group_size        = local_head_num / local_head_num_kv;
    const int tokens_per_block  = attn_configs.tokens_per_block;

    // map the quantized type to the underlying buffer dtype
    if (dtype == DataType::TYPE_QFP8_E4M3) {
        dtype = DataType::TYPE_FP8_E4M3;
    }

    if (!attn_configs.use_mla) {
        // the non-MLA FlashInfer path only supports a restricted set of dtypes, rope styles,
        // head sizes and GQA group sizes; otherwise fall back to the default attention ops
        if ((dtype != DataType::TYPE_FP16 && dtype != DataType::TYPE_BF16 && dtype != DataType::TYPE_FP8_E4M3)
            || (attn_configs.kv_cache_dtype != KvCacheDataType::BASE
                && !(attn_configs.kv_cache_dtype == KvCacheDataType::FP8 && dtype == DataType::TYPE_FP8_E4M3))
            || (attn_configs.rope_config.style != RopeStyle::Base && attn_configs.rope_config.style != RopeStyle::No)
            || attn_configs.mask_type != causalMask
            || attn_configs.q_scaling != 1.0f
            || attn_configs.use_logn_attn
            || (size_per_head != 64 && size_per_head != 128 && size_per_head != 192)
            || (group_size > 10 && group_size != 16)) {
            return nullptr;
        }
    }

    int input_token_num = 0;
    if (prefix_lengths_host) {
        input_token_num = std::accumulate(input_lengths_host->data<int>(),
                                          input_lengths_host->data<int>() + batch_size,
                                          0);
    } else {
        input_token_num = input_lengths_host->shape()[0];
    }

    auto params = FlashInferAttnParams::create(cuda_device,
                                               max(MIN_CACHE_BATCH_SIZE, batch_size),
                                               max(MIN_CACHE_INPUT_TOKEN_NUM, input_token_num),
                                               MIN_CACHE_PAGE_NUM);
    FlashInferAttnParamsPtr ret(params, FlashInferAttnParamsDel);

    if (kv_cache_block_id_device) {
        params->kv_cache_block_id_d = Buffer2torchTensor(kv_cache_block_id_device, false);
    }
    params->mla_ops_type = mla_ops_type;
    params->dtype        = dtype;

    params->fillFlashInfer(prefix_lengths_host,
                           sequence_lengths_host,
                           input_lengths_host,
                           kv_cache_block_id_host,
                           batch_size,
                           tokens_per_block);
    params->refreshFlashInferBuf(cuda_device, batch_size, input_token_num);

    // large GQA group sizes go through the prefill kernel even when decoding
    if (group_size > 5) {
        params->decode = false;
    } else {
        params->decode = true;
    }
    params->genPlan(batch_size,
                    q_length,
                    local_head_num,
                    local_head_num_kv,
                    size_per_head,
                    tokens_per_block,
                    attn_configs.kv_lora_rank,
                    attn_configs.use_mla,
                    reinterpret_cast<int64_t>(cuda_device->getStream()));  // cuda_stream
    return ret;
}
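// Executes attention for one batch: adds the QKV bias in place, applies RoPE to the new Q/K,
// appends K/V to the paged cache, then runs the FlashInfer decode or prefill kernel; FP8 outputs
// are written to the provided f16_out buffer and rescaled into params.output at the end.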
void FlashInferAttnParams::run(const AttentionModuleParams& params,
                               const BufferPtr&             f16_out,
                               std::function<void()>        moe_insertion_callback,
                               int64_t                      stream)
{
    const int local_head_num    = params.configs.head_num;
    const int local_head_num_kv = params.configs.kv_head_num;
    const int size_per_head     = params.configs.size_per_head;

    if (params.weights.qkv_weight->bias) {
        at::Tensor qkv_input = Buffer2torchTensor(params.input, false);
        qkv_input.add_(Buffer2torchTensor(params.weights.qkv_weight->bias, false));
    }

    const int             bs      = params.input.shape()[0];
    const vector<int64_t> strides = {(local_head_num + 2 * local_head_num_kv) * size_per_head, size_per_head, 1};
    const auto            cuda_option =
        torch::dtype(dataTypeToTorchType(params.input.type())).device(torch::DeviceType::CUDA).requires_grad(false);

    auto q = torch::from_blob(params.input.data(),
                              {bs, local_head_num, size_per_head},
                              strides,
                              cuda_option);
    auto append_k = torch::from_blob(params.input.dataWithOffset(local_head_num * size_per_head),
                                     {bs, local_head_num_kv, size_per_head},
                                     strides,
                                     cuda_option);

    apply_rope_pos_ids(q,
                       append_k,
                       q,
                       append_k,
                       positions_d,
                       params.configs.rope_config.dim,
                       false,
                       params.configs.rope_config.scale,
                       params.configs.rope_config.base,
                       stream);
    sync_check_cuda_error();

    auto append_v = torch::from_blob(params.input.dataWithOffset((local_head_num + local_head_num_kv) * size_per_head),
                                     {bs, local_head_num_kv, size_per_head},
                                     strides,
                                     cuda_option);
    auto k_cache = Buffer2torchTensor(params.common.kv_cache->k_cache_buffer, false);
    auto v_cache = Buffer2torchTensor(params.common.kv_cache->v_cache_buffer, false);
    if (append_k.type() != k_cache.type()) {
        append_k = append_k.to(k_cache.type());
        append_v = append_v.to(k_cache.type());
    }
    append_paged_kv_cache(append_k,
                          append_v,
                          batch_indice_d,
                          positions_d,
                          k_cache,
                          v_cache,
                          page_indice_d,
                          page_indptr_d,
                          paged_kv_last_page_len_d,
                          1,
                          stream);
    moe_insertion_callback();
    sync_check_cuda_error();

    auto softmax_scale = (1.0f / sqrtf(size_per_head * 1.0f)) * params.configs.softmax_extra_scale;

    at::Tensor out;
    if (params.output.type() == DataType::TYPE_FP8_E4M3) {
        out = Buffer2torchTensor(f16_out, false);
    } else {
        out = Buffer2torchTensor(params.output, false);
    }

    if (decode) {
        BatchDecodeWithPagedKVCacheRun(float_workspace_d,         // float_workspace_buffer
                                       int_workspace_d,           // int_workspace_buffer
                                       plan,                      // plan_info_vec
                                       q,                         // q
                                       k_cache,                   // paged_k_cache
                                       v_cache,                   // paged_v_cache
                                       page_indptr_d,             // paged_kv_indptr
                                       page_indice_d,             // paged_kv_indices
                                       paged_kv_last_page_len_d,  // paged_kv_last_page_len
                                       out,
                                       std::nullopt,              // maybe_lse
                                       1,                         // kv_layout_code
                                       -1,                        // window_left
                                       std::nullopt,              // maybe_alibi_slopes
                                       0,                         // logits_soft_cap
                                       softmax_scale,
                                       0,
                                       0,
                                       stream);
    } else {
        BatchPrefillWithPagedKVCacheRun(float_workspace_d,         // float_workspace_buffer
                                        int_workspace_d,           // int_workspace_buffer
                                        plan,                      // plan_info_vec
                                        q,                         // q
                                        k_cache,                   // paged_k_cache
                                        v_cache,                   // paged_v_cache
                                        qo_indptr_d,               // qo_indptr
                                        page_indptr_d,             // paged_kv_indptr
                                        page_indice_d,             // paged_kv_indices
                                        paged_kv_last_page_len_d,  // paged_kv_last_page_len
                                        out,
                                        std::nullopt,              // maybe_lse
                                        1,                         // mask_mode_code
                                        1,                         // layout
                                        -1,                        // window_left
                                        std::nullopt,              // maybe_custom_mask
                                        std::nullopt,              // maybe_mask_indptr
                                        std::nullopt,              // maybe_alibi_slopes
                                        0,                         // logits_soft_cap
                                        softmax_scale,
                                        params.configs.rope_config.scale,
                                        params.configs.rope_config.base,
                                        stream);
    }

    const auto& scale = params.weights.static_scale_reciprocal_weight;
    if (scale) {
        auto scale_t = Buffer2torchTensor(scale->kernel, false);
        auto fp8_out = Buffer2torchTensor(params.output, false);
        fp8_out.copy_((scale_t * out).to(torch::kFloat8_e4m3fn));
    }
}

}  // namespace rtp_llm