void CudaDevice::prefillAttention()

in maga_transformer/cpp/devices/cuda_impl/CudaPrefillAttention.cc [77:274]


void CudaDevice::prefillAttention(const AttentionModuleParams& params,
                                  KVBlockArray                 kv_block_array,
                                  const BufferPtr&             q_output,
                                  const BufferPtr&             k_output,
                                  const BufferPtr&             v_output,
                                  const BufferPtr&             qkv_buf_fp8) {
    auto fmha_type = fmha_type_;
    auto stream  = stream_;
    auto cufmha_runner = cufmha_runner_;
    RTP_LLM_LOG_DEBUG("using FMHA type %s", std::to_string((int)fmha_type).c_str());
    auto      datatype            = params.input.type();
    auto      token_num           = params.input.shape()[0];
    auto      batch_size          = params.common.context_batch_size;
    auto      seq_len             = params.common.context_max_seq_len;
    auto      seq_len_with_prefix = seq_len + params.common.max_prefix_length;
    auto      head_num            = params.configs.head_num;
    auto      kv_head_num         = params.configs.kv_head_num;
    auto      size_per_head       = params.configs.size_per_head;
    bool      use_fp8_fmha        = qkv_buf_fp8 != nullptr;
    BufferPtr tiled_counter_ptr;
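    // the TRT V2 fused-MHA kernels need a device-side tile counter, allocated here and zeroed before launch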
    if (FMHAType::PAGED_TRT_V2 == fmha_type || FMHAType::TRT_V2 == fmha_type) {
        tiled_counter_ptr =
            allocateBuffer({DataType::TYPE_UINT32, {1}, AllocationType::DEVICE}, {"tiled_counter_pointer"});
        cudaMemsetAsync(tiled_counter_ptr->data(), 0, sizeof(uint32_t), stream);
    }
    switch (fmha_type) {
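        // TensorRT fused MHA v2 with K/V read from the paged KV-cache blocks; Q comes from the separate q_output buffer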
        case FMHAType::PAGED_TRT_V2: {
            RTP_LLM_CHECK_WITH_INFO(q_output != nullptr, "q_output must be provided for paged trt v2 fmha");
            cufmha_runner->runTrtV2FmhaPaged(q_output->data(),
                                             params.common.cu_seqlens->data(),
                                             params.common.cu_kv_seqlens->data(),
                                             params.output.data(),
                                             reinterpret_cast<uint32_t*>(tiled_counter_ptr->data()),
                                             batch_size,
                                             seq_len,
                                             seq_len_with_prefix,
                                             token_num,
                                             kv_block_array,
                                             false,
                                             false,
                                             params.common.linear_bias_slopes != nullptr,
                                             false);
            break;
        }
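        // TensorRT fused MHA v2 on the contiguous QKV buffer, optionally running in FP8 with a quantized output scale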
        case FMHAType::TRT_V2: {
            void*  fmha_input_ptr                    = use_fp8_fmha ? qkv_buf_fp8->data() : params.input.data();
            void*  fmha_output_ptr                   = params.output.data();
            RTP_LLM_CHECK_WITH_INFO(fmha_input_ptr, "fmha_input_ptr must be provided for trt v2 fmha");
            float* attention_output_orig_quant_scale = nullptr;
            if (params.weights.static_scale_reciprocal_weight && use_fp8_fmha) {
                printBufferData(*(params.weights.static_scale_reciprocal_weight->kernel), "attn scale");
                attention_output_orig_quant_scale =
                    (params.weights.static_scale_reciprocal_weight->kernel->data<float>());
            }
            bool      need_quant_fmha_out = !use_fp8_fmha && params.output.isQBuffer();
            BufferPtr tmp_fmha_output;
            if (need_quant_fmha_out) {
                // sm89 cannot run fp8 FMHA, but the attention output must be fp8, so run FMHA in fp16 and quantize afterwards
                tmp_fmha_output = allocateBuffer({DataType::TYPE_FP16,
                                                  {batch_size, head_num * seq_len_with_prefix * size_per_head},
                                                  AllocationType::DEVICE},
                                                 {"fmha_fp16_output"});
                cudaMemsetAsync(tmp_fmha_output->data(), 0, tmp_fmha_output->sizeBytes(), stream);
                fmha_output_ptr = tmp_fmha_output->data();
            }
            RTP_LLM_CHECK_WITH_INFO(fmha_output_ptr, "fmha_output_ptr must be provided for trt v2 fmha");
            cufmha_runner->runTrtV2Fmha(fmha_input_ptr,
                                        params.common.cu_seqlens->data(),
                                        fmha_output_ptr,
                                        reinterpret_cast<uint32_t*>(tiled_counter_ptr->data()),
                                        attention_output_orig_quant_scale,
                                        batch_size,
                                        seq_len,
                                        token_num,
                                        kv_block_array,
                                        false,
                                        false,
                                        params.common.linear_bias_slopes != nullptr,
                                        false);
            if (need_quant_fmha_out) {
                DataType quant_out_data_type = DataType::TYPE_FP8_E4M3;
                auto     quant_params =
                    QuantizeParams(*tmp_fmha_output,
                                   quant_out_data_type,
                                   1,
                                   QScheme::Qfp8PerTensor,
                                   std::nullopt,
                                   std::nullopt,
                                   (OptionalConstBufferRef)*params.weights.static_quant_weight->kernel,
                                   (OptionalConstBufferRef)*params.weights.static_scale_reciprocal_weight->kernel);
                auto quant_output = quantize(quant_params);
                cudaMemcpyAsync(
                    params.output.data(), quant_output->data(), params.output.size(), cudaMemcpyDeviceToDevice, stream);
            }
            break;
        }
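        // open-source FMHA kernel reading K/V from the paged cache buffers; workspace is sized and allocated per call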
        case FMHAType::PAGED_OPEN_SOURCE: {
            const size_t max_blocks_per_batch = params.common.kv_cache->kv_cache_block_id->shape()[1];
            const auto   ws_size              = cufmha_runner->getOpenSourceWorkSpaceSize(
                batch_size, seq_len, max_blocks_per_batch * params.configs.tokens_per_block, true);
            auto ws = allocateBuffer({DataType::TYPE_INT8, {ws_size}, AllocationType::DEVICE},
                                     {"open_source_paged_fmha_ws"});
            cufmha_runner->runOpenSourceFmhaPaged(
                params.input.data(),
                params.common.kv_cache->k_cache_buffer->data(),
                params.common.kv_cache->v_cache_buffer->data(),
                params.output.data(),
                params.common.cu_seqlens->data<int>(),
                params.common.cu_kv_seqlens->data<int>(),
                params.common.kv_cache->kv_cache_block_id->data<int>(),
                batch_size,
                max_blocks_per_batch,
                params.configs.tokens_per_block,
                seq_len,
                ws->data(),
                params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data<float>() : nullptr,
                params.configs.softmax_extra_scale);
            break;
        }
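        // open-source FMHA on the packed QKV input; Q/K/V pointers are derived by offsetting into params.input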
        case FMHAType::OPEN_SOURCE: {
            const auto   ws_size      = cufmha_runner->getOpenSourceWorkSpaceSize(batch_size, seq_len);
            auto         ws           = allocateBuffer({DataType::TYPE_INT8, {ws_size}, AllocationType::DEVICE},
                                                       {"open_source_fmha_ws"});
            const size_t hidden_units = head_num * size_per_head;
            const size_t hidden_units_kv = kv_head_num * size_per_head;
            cufmha_runner->runOpenSourceFmha(
                params.input.data(),
                params.input.dataWithOffset(hidden_units),
                params.input.dataWithOffset(hidden_units + hidden_units_kv),
                params.output.data(),
                params.common.cu_seqlens->data<int>(),
                batch_size,
                seq_len,
                ws->data(),
                params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data<float>() : nullptr,
                params.configs.softmax_extra_scale);
            break;
        }
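        // legacy TensorRT fused MHA v1, which needs a temporary QKV scratch buffer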
        case FMHAType::TRT_V1: {
            auto qkv_buf_temp = allocateBuffer(
                {datatype, {token_num, head_num + 2 * kv_head_num, size_per_head}, AllocationType::DEVICE},
                {"qkv_buf_temp"});
            cufmha_runner->runTrtV1Fmha(params.input.data(),
                                        params.common.cu_seqlens->data(),
                                        params.output.data(),
                                        qkv_buf_temp->data(),
                                        batch_size,
                                        seq_len,
                                        token_num);
            break;
        }
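        // fallback unfused path: Q*K^T GEMM, masked softmax (with optional linear bias slopes), then *V GEMM and a transpose that removes padding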
        default: {
            RTP_LLM_CHECK_WITH_INFO(q_output && k_output && v_output, "q_output/k_output/v_output must be provided for default context attention implementation");
            q_output->updateShape({batch_size, kv_head_num, (head_num / kv_head_num) * seq_len, size_per_head});
            auto qk_output = gemm({*q_output,
                                   *k_output,
                                   std::nullopt,
                                   nullptr,
                                   DataType::TYPE_FP32,
                                   TransposeOperation::NONE,
                                   TransposeOperation::TRANSPOSE});
            qk_output->updateShape({batch_size, head_num, seq_len, seq_len_with_prefix});
            printBufferData(*qk_output, "qk_output: ");
            float scale = (1.0f / sqrtf(size_per_head * 1.0f)) * params.configs.softmax_extra_scale;
            // TODO(lidongjin): only supports float32 input and float16 output.
            RUNTIME_ASSERT_OP_ARG(params.common.attention_mask,
                                  "attention_mask must be provided for default context attention implementation");
            auto softmax_qk_output = softmax({std::move(qk_output),
                                              *params.common.attention_mask,
                                              std::nullopt,
                                              scale,
                                              datatype,
                                              params.common.linear_bias_slopes ?
                                                  (OptionalConstBufferRef)*params.common.linear_bias_slopes :
                                                  std::nullopt});
            softmax_qk_output->updateShape(
                {batch_size, kv_head_num, (head_num / kv_head_num) * seq_len, seq_len_with_prefix});
            printBufferData(*softmax_qk_output, "softmax_qk_output: ");
            auto qkv_output = gemm({*softmax_qk_output, *v_output});
            qkv_output->updateShape({batch_size, head_num, seq_len, size_per_head});
            printBufferData(*qkv_output, "qkv_output");
            auto& qkv_transpose_output = params.output;
            DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype,
                                             invokeTransposeAttentionOutRemovePadding,
                                             qkv_output->data(),
                                             qkv_transpose_output.data(),
                                             token_num,
                                             batch_size,
                                             seq_len,
                                             head_num,
                                             size_per_head,
                                             params.common.padding_offset->data<int>(),
                                             nullptr,
                                             0,
                                             stream);
        }
    }
}