// maga_transformer/cpp/cuda/cuda_fmha_utils.h
#pragma once
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/utils/Logger.h"
#include "maga_transformer/cpp/cuda/cuda_utils.h"
#include "maga_transformer/cpp/core/Types.h"
#include "maga_transformer/cpp/th_op/GptInitParameter.h"
#include "3rdparty/trt_fused_multihead_attention/qkvToContext.h"
#include "3rdparty/contextFusedMultiHeadAttention/fmhaRunner.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
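
// CudaFmhaUtils centralizes the runtime checks that decide which fused
// multi-head attention (FMHA) backend a model instance can use. Each
// UseXxxFMHA predicate combines hardware capability (SM version), model
// configuration (GptInitParameter), and environment-variable overrides.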
namespace rtp_llm {
class CudaFmhaUtils {
public:
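    // Returns true when the TRT-LLM fused context attention kernel can be used:
    // requires SM 7.0 / 8.x / 9.0, dense heads, equal Q/KV lengths, and the
    // kernel itself reporting support for this head size on this SM.
    // Can be force-disabled with env ENABLE_TRT_FMHA=OFF.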
    template<typename T>
    static bool UseTrtFMHA(const rtp_llm::GptInitParameter& gpt_init_parameter) {
        bool use_trt_fmha = CheckUseFMHA<T>(gpt_init_parameter) && CheckQKVLengthEqual<T>(gpt_init_parameter);
        if (!(is_sm8x() || is_sm90() || is_sm70())) {
            RTP_LLM_LOG_INFO("TRT FMHA is disabled for sm %d", get_sm());
            use_trt_fmha = false;
        }
        if (gpt_init_parameter.is_sparse_head_) {
            RTP_LLM_LOG_INFO("TRT FMHA is disabled for sparse head");
            use_trt_fmha = false;
        }
        char* fmha_env = std::getenv("ENABLE_TRT_FMHA");
        if (fmha_env && std::string(fmha_env) == "OFF") {
            RTP_LLM_LOG_INFO("TRT FMHA is disabled by env ENABLE_TRT_FMHA=OFF");
            use_trt_fmha = false;
        }
        if (!tensorrt_llm::kernels::MHARunner::fmha_supported(gpt_init_parameter.size_per_head_, rtp_llm::get_sm())) {
            RTP_LLM_LOG_INFO("TRT FMHA is disabled by fmha_supported check");
            use_trt_fmha = false;
        }
        return use_trt_fmha;
    }
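    // Returns true when the legacy trt_fused_multihead_attention kernel can be
    // used. Only compiled in with USE_OLD_TRT_FMHA; restricted to half
    // precision and plain MHA (head_num == head_num_kv), and verified by
    // instantiating a FusedMHARunnerFP16v2 probe.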
    template<typename T>
    static bool UseOldTrtFMHA(const rtp_llm::GptInitParameter& gpt_init_parameter) {
#ifdef USE_OLD_TRT_FMHA
        bool use_old_trt_fmha = CheckUseFMHA<T>(gpt_init_parameter) && CheckQKVLengthEqual<T>(gpt_init_parameter);
        if (!use_old_trt_fmha) {
            return false;
        }
        if (!std::is_same<T, half>::value) {
            RTP_LLM_LOG_INFO("OLD TRT FMHA only supports half");
            return false;
        }
        if (gpt_init_parameter.head_num_ != gpt_init_parameter.head_num_kv_) {
            RTP_LLM_LOG_INFO("OLD TRT FMHA does not support head_num != head_num_kv");
            return false;
        }
        auto testRunner = FusedMHARunnerFP16v2(gpt_init_parameter.head_num_, gpt_init_parameter.size_per_head_, get_sm(), gpt_init_parameter.q_scaling_);
        if (!testRunner.fmha_supported(gpt_init_parameter.is_causal_)) {
            RTP_LLM_LOG_INFO("OLD TRT FMHA is disabled by fmha_supported check");
            return false;
        }
        return true;
#else
        RTP_LLM_LOG_INFO("OLD TRT FMHA is disabled: USE_OLD_TRT_FMHA is not defined");
        return false;
#endif
    }
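    // Returns true when the paged TRT-LLM FMHA kernel can be used. Unlike
    // UseTrtFMHA, it does not require equal Q/KV lengths (the paged kernel
    // reads past KV from the cache), but it requires SM 8.x / 9.0, an
    // unquantized kv cache, and dense heads.
    // Can be force-disabled with env ENABLE_PAGED_TRT_FMHA=OFF.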
    template<typename T>
    static bool UsePagedTrtFMHA(const rtp_llm::GptInitParameter& gpt_init_parameter) {
        bool use_paged_trt_fmha = CheckUseFMHA<T>(gpt_init_parameter);
        if (!(is_sm8x() || is_sm90())) {
            RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled for sm %d", get_sm());
            use_paged_trt_fmha = false;
        }
        if (!gpt_init_parameter.use_kvcache_) {
            RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled when kv cache is not used");
            use_paged_trt_fmha = false;
        }
        if (gpt_init_parameter.is_sparse_head_) {
            RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled for sparse head");
            use_paged_trt_fmha = false;
        }
        if (gpt_init_parameter.isKvCacheQuant()) {
            RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled for quantized kv cache");
            use_paged_trt_fmha = false;
        }
        char* paged_fmha_env = std::getenv("ENABLE_PAGED_TRT_FMHA");
        if (paged_fmha_env && std::string(paged_fmha_env) == "OFF") {
            RTP_LLM_LOG_INFO("Paged TRT FMHA is disabled by env ENABLE_PAGED_TRT_FMHA=OFF");
            use_paged_trt_fmha = false;
        }
        return use_paged_trt_fmha;
    }
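    // Returns true when the open-source flash-attention kernel can be used:
    // requires SM 8.x / 9.0 and equal Q/KV lengths.
    // Can be force-disabled with env ENABLE_OPENSOURCE_FMHA=OFF.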
    template<typename T>
    static bool UseOpenSourceFMHA(const rtp_llm::GptInitParameter& gpt_init_parameter) {
        bool use_open_source_fmha = CheckUseFMHA<T>(gpt_init_parameter) && CheckQKVLengthEqual<T>(gpt_init_parameter);
        if (!(is_sm8x() || is_sm90())) {
            RTP_LLM_LOG_INFO("opensource FMHA is disabled for sm %d", get_sm());
            use_open_source_fmha = false;
        }
        char* fmha_env = std::getenv("ENABLE_OPENSOURCE_FMHA");
        if (fmha_env && std::string(fmha_env) == "OFF") {
            RTP_LLM_LOG_INFO("opensource FMHA is disabled by env ENABLE_OPENSOURCE_FMHA=OFF");
            use_open_source_fmha = false;
        }
        return use_open_source_fmha;
    }
protected:
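    // Common gate for all FMHA backends: FMHA must not be globally disabled
    // (env ENABLE_FMHA=OFF) and the compute type must not be float.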
    template<typename T>
    static bool CheckUseFMHA(const rtp_llm::GptInitParameter& params) {
        char* fmha_env = std::getenv("ENABLE_FMHA");
        bool fmha_enable = (fmha_env == nullptr || std::string(fmha_env) != "OFF");
        if (!fmha_enable) {
            RTP_LLM_LOG_INFO("FMHA is disabled by env ENABLE_FMHA=OFF");
            return false;
        }
        if (std::is_same<T, float>::value) {
            RTP_LLM_LOG_INFO("FMHA does not support float");
            return false;
        }
        return true;
    }
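    // Returns true only when every context token attends over exactly the
    // tokens of its own prompt, i.e. Q length equals KV length. Any feature
    // that prepends cached or speculative tokens (prefix prompt, kv cache
    // reuse, sp model, multi-task prompt, partial fallback, fast gen) breaks
    // this assumption, so backends that require equal lengths are disabled.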
    template<typename T>
    static bool CheckQKVLengthEqual(const rtp_llm::GptInitParameter& params) {
        char* reuse_cache_env = std::getenv("REUSE_CACHE");
        bool not_prefix =
            params.pre_seq_len_ == 0 && (reuse_cache_env == nullptr || std::string(reuse_cache_env) != "1");
        char* multi_task_prompt_env = std::getenv("MULTI_TASK_PROMPT");
        char* multi_task_prompt_str_env = std::getenv("MULTI_TASK_PROMPT_STR");
        char* sp_model_env = std::getenv("SP_MODEL_TYPE");
        char* enable_partial_fallback_env = std::getenv("ENABLE_PARTIAL_FALLBACK");
        if (enable_partial_fallback_env != nullptr && std::string(enable_partial_fallback_env) == "1") {
            RTP_LLM_LOG_INFO("QKV length not equal: partial fallback enabled");
            return false;
        }
        char* enable_fast_gen_env = std::getenv("ENABLE_FAST_GEN");
        if (enable_fast_gen_env != nullptr && std::string(enable_fast_gen_env) == "1") {
            RTP_LLM_LOG_INFO("QKV length not equal: fast gen enabled");
            return false;
        }
        if (!not_prefix) {
            RTP_LLM_LOG_INFO("QKV length not equal: kv cache reuse enabled");
            return false;
        }
        if (sp_model_env != nullptr) {
            RTP_LLM_LOG_INFO("QKV length not equal: sp_model used");
            return false;
        }
        if (multi_task_prompt_env && strcmp(multi_task_prompt_env, "") != 0) {
            RTP_LLM_LOG_INFO("QKV length not equal: multi_task_prompt used");
            return false;
        }
        if (multi_task_prompt_str_env && strcmp(multi_task_prompt_str_env, "") != 0) {
            RTP_LLM_LOG_INFO("QKV length not equal: multi_task_prompt_str used");
            return false;
        }
        return true;
    }
};
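
// A minimal usage sketch (hypothetical call site; the actual dispatch lives in
// the attention layer setup, not in this header):
//
//   rtp_llm::GptInitParameter params = ...;
//   if (rtp_llm::CudaFmhaUtils::UsePagedTrtFMHA<half>(params)) {
//       // construct the paged TRT-LLM FMHA runner
//   } else if (rtp_llm::CudaFmhaUtils::UseTrtFMHA<half>(params)) {
//       // construct the plain TRT-LLM FMHA runner
//   }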
}  // namespace rtp_llm