// maga_transformer/cpp/deep_gemm/DeepGemmPlugin.cpp
#include <vector>
#include <unordered_map>
#include <tuple>
#include <algorithm>
#include <torch/torch.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include "autil/EnvUtil.h"
#include "maga_transformer/cpp/deep_gemm/utils.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/cuda/cuda_utils.h"
#include "maga_transformer/cpp/core/QBuffer.h"
#include "maga_transformer/cpp/deep_gemm/DeepGemmPlugin.h"

using namespace std;

namespace rtp_llm {

#ifdef ENABLE_FP8
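
// dispatchBlockNK is the compile-time entry into the DeepGEMM kernels: each (N, K, GROUP_NUM,
// GEMM_TYPE) combination is a separate template instantiation, which is presumably compiled in
// the accompanying .cu translation units (an assumption based on these declarations; only the
// shapes enumerated in runDeepGemm below are actually available). runDeepGemm is the runtime
// dispatcher that maps a concrete problem size onto one of those instantiations.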
template<uint32_t N, uint32_t K, uint32_t GROUP_NUM, DeepGemmType GEMM_TYPE>
void dispatchBlockNK(__nv_bfloat16* output,
                     __nv_fp8_e4m3* lhs,
                     float*         lhs_scale,
                     __nv_fp8_e4m3* rhs,
                     float*         rhs_scale,
                     int*           grouped_layout,
                     uint32_t       m,
                     uint32_t       bm,
                     uint32_t       bn,
                     uint32_t       bk,
                     uint32_t       num_stages,
                     uint32_t       num_tma_multicast,
                     cudaStream_t   stream,
                     uint32_t       num_sms,
                     uint32_t       smem_size);

void runDeepGemm(__nv_bfloat16* output,
                 __nv_fp8_e4m3* lhs,
                 float*         lhs_scale,
                 __nv_fp8_e4m3* rhs,
                 float*         rhs_scale,
                 int*           grouped_layout,
                 uint32_t       m,
                 uint32_t       n,
                 uint32_t       k,
                 uint32_t       bm,
                 uint32_t       bn,
                 uint32_t       bk,
                 uint32_t       num_groups,
                 uint32_t       num_stages,
                 uint32_t       num_tma_multicast,
                 DeepGemmType   gemm_type,
                 cudaStream_t   stream,
                 uint32_t       num_sms,
                 uint32_t       smem_size);
inline int DeepGemmPlugin::getNumSms() {
    static int num_sms = -1;
    if (num_sms != -1) {
        return num_sms;
    }
    cudaDeviceProp properties;
    int device_idx;
    check_cuda_error(cudaGetDevice(&device_idx));
    check_cuda_error(cudaGetDeviceProperties(&properties, device_idx));
    num_sms = properties.multiProcessorCount;
    RTP_LLM_LOG_INFO("cuda device property has sm num %d", num_sms);
    num_sms = autil::EnvUtil::getEnv("DEEP_GEMM_NUM_SM", num_sms);
    RTP_LLM_LOG_INFO("deep gemm uses sm num %d", num_sms);
    return num_sms;
}

int getMaxSmem() {
    static int max_smem_per_block = -1;
    if (max_smem_per_block != -1) {
        return max_smem_per_block;
    }
    int device_idx = 0;
    check_cuda_error(cudaGetDevice(&device_idx));
    check_cuda_error(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx));
    return max_smem_per_block;
}

inline int ceil_div(int a, int b) {
    RTP_LLM_CHECK_WITH_INFO(b != 0, "divisor cannot be zero");
    return (a + b - 1) / b;
}
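
// TMA requires the leading dimension of each tensor it copies to be 16-byte aligned, so
// per-token scale rows are padded up to the next multiple of (16 / sizeof(element)).
// Illustrative example (not taken from a real call site): for fp32 scales (4 bytes) the
// alignment is 4 elements, so getTmaAlignedSize(130, 4) rounds 130 up to 132.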
inline int getTmaAlignedSize(int x, int data_size) {
    int tma_alignment_bytes = 16, alignment;
    RTP_LLM_CHECK_WITH_INFO(tma_alignment_bytes % data_size == 0, "TMA alignment bytes 16 must be divisible by data size");
    alignment = tma_alignment_bytes / data_size;
    return ceil_div(x, alignment) * alignment;
}

inline int fixWaveSaturate(int x, int num_sms) {
    return (x == 0) ? num_sms : x;
}
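
// A "wave" is one round of thread blocks across all SMs: the grid has
// ceil(m/bm) * ceil(n/bn) * num_groups blocks, and num_sms of them run concurrently.
// getNumWaves counts the rounds; getLastWaveUtil counts how many SMs the final, possibly
// partial, round still keeps busy. Illustrative example (hypothetical sizes): m=4096, n=7168,
// bm=bn=128, num_groups=1, num_sms=132 gives 32 * 56 = 1792 blocks, i.e. 14 waves with
// 1792 % 132 = 76 SMs busy in the last one.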
inline int getNumWaves(int m, int n, int bm, int bn, int num_groups, int num_sms) {
    auto m_w = ceil_div(m, bm), n_w = ceil_div(n, bn);
    return ceil_div(m_w * n_w * num_groups, num_sms);
}

inline int getLastWaveUtil(int m, int n, int bm, int bn, int num_groups, int num_sms) {
    auto m_w = ceil_div(m, bm), n_w = ceil_div(n, bn);
    return fixWaveSaturate(m_w * n_w * num_groups % num_sms, num_sms);
}

inline bool isTmaMulticastLegal(int shape_dim, int block_dim, int num_tma_multicast, int num_sms) {
    if (num_tma_multicast == 1) {
        return true;
    }
    return (shape_dim % (block_dim * num_tma_multicast) == 0) && (num_sms % num_tma_multicast) == 0;
}
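
// Per-block shared memory budget: the bf16 output tile, num_stages pipelined fp8 A/B tiles plus
// per-stage A scales, the (8-byte padded) B scales, and the mbarrier storage.
// Illustrative example (hypothetical shape, assuming an H100-class opt-in limit of 232,448
// bytes / 227 KiB): k=7168, bm=bn=bk=128 gives 32768 (D) + stages*33280 (A + B + A scales)
// + 224 (B scales) + stages*16 (barriers); 6 stages would need 232,768 bytes and be rejected,
// while 5 stages (199,472 bytes) fit.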
inline int getSmemSize(int num_stages, int k, int bm, int bn, int bk = 128) {
    int smem_d = bm * bn * 2;
    int smem_a_per_stage = bm * bk;
    int smem_scales_a_per_stage = bm * 4;
    int smem_b_per_stage = bn * bk;
    int smem_scales_b = ceil_div(k, bk) * 4;
    int smem_barrier = num_stages * 8 * 2;

    int smem_size = 0;
    smem_size += smem_d;
    smem_size += num_stages * smem_a_per_stage;
    smem_size += num_stages * smem_scales_a_per_stage;
    smem_size += num_stages * smem_b_per_stage;
    int scaler = (bk % bn == 0) ? 1 : 2;
    smem_size += ceil_div(smem_scales_b * scaler, 8) * 8;
    smem_size += smem_barrier;
    return smem_size;
}
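
// DeepGEMM expects the per-token LHS scales in a column-major layout with the M dimension
// padded to the TMA alignment, so the (2D or 3D) row-major scale buffer is copied into a
// freshly allocated [g, k, aligned_m] tensor viewed as [g, aligned_m, k] with M contiguous.
// Note that this allocates and copies on every call, so the contiguous grouped path pays that
// cost per GEMM.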
torch::Tensor getColMajorTmaAlignedTensor(Buffer lhs_scale) {
    RTP_LLM_CHECK_WITH_INFO(lhs_scale.dim() == 2 || lhs_scale.dim() == 3, "lhs scale must be dim 2 or 3");
    RTP_LLM_CHECK_WITH_INFO(lhs_scale.type() == DataType::TYPE_FP32, "lhs scale must be fp32");
    int remove_dim = 0;
    if (lhs_scale.dim() == 2) {
        remove_dim = 1;
    }

    size_t g, m, k;
    g = remove_dim ? 1 : lhs_scale.shape()[0];
    m = lhs_scale.shape()[1 - remove_dim];
    k = lhs_scale.shape()[2 - remove_dim];

    int  aligned_m = getTmaAlignedSize(m, lhs_scale.typeSize());
    auto col_major_lhs_scale = torch::transpose(
        torch::empty({int(g), int(k), int(aligned_m)}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)),
        1, 2);
    col_major_lhs_scale.index_put_(
        {torch::indexing::Slice(), torch::indexing::Slice(0, m), torch::indexing::Slice()},
        torch::from_blob(
            lhs_scale.data(), {int(g), int(m), int(k)}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)));
    if (remove_dim) {
        return col_major_lhs_scale.squeeze(0);
    } else {
        return col_major_lhs_scale;
    }
}
class DeepGemmConfig {
public:
    uint32_t block_m, block_n, num_stages, num_tma_multicast, smem_size;
    DeepGemmConfig(uint32_t block_m, uint32_t block_n, uint32_t num_stages, uint32_t num_tma_multicast, uint32_t smem_size):
        block_m(block_m),
        block_n(block_n),
        num_stages(num_stages),
        num_tma_multicast(num_tma_multicast),
        smem_size(smem_size) {}
};
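
// Block-size selection mirrors DeepGEMM's heuristic: fix block_m from the M extent, then scan
// candidate block_n values, preferring fewer waves and, on ties, better last-wave utilization
// (with larger block_m / smaller block_n as the final tie-breakers). The deepest pipeline that
// still fits in shared memory is chosen, and TMA multicast (2) is enabled only for large,
// single-group problems where N divides evenly. Results are memoized per
// (m, n, k, num_sms, is_grouped_contiguous) key.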
DeepGemmConfig getBestConfig(int m, int n, int k, int num_groups, int num_sms, bool is_grouped_contiguous = false) {
    static unordered_map<uint64_t, DeepGemmConfig> best_configs;
    uint64_t key = ((uint64_t)m << 44) | ((uint64_t)(n & 0xffff) << 28) | ((uint64_t)(k & 0xffff) << 12)
                   | ((uint64_t)num_sms << 4) | ((uint64_t)is_grouped_contiguous);
    auto it = best_configs.find(key);
    if (it != best_configs.end()) {
        return it->second;
    }

    int block_m;
    if (!is_grouped_contiguous && m <= 64) {
        block_m = 64;
    } else {
        block_m = 128;
    }

    int best_block_m = -1, best_block_n = -1;
    for (int block_n : std::vector<int>({16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128})) {
        bool success = false;
        if (best_block_m == -1 || best_block_n == -1) {
            success = true;
        } else {
            int num_waves      = getNumWaves(m, n, block_m, block_n, num_groups, num_sms);
            int best_num_waves = getNumWaves(m, n, best_block_m, best_block_n, num_groups, num_sms);
            if (num_waves < best_num_waves) {
                success = true;
            } else if (num_waves == best_num_waves) {
                int util      = getLastWaveUtil(m, n, block_m, block_n, num_groups, num_sms);
                int best_util = getLastWaveUtil(m, n, best_block_m, best_block_n, num_groups, num_sms);
                success = tie(util, block_m, best_block_n) > tie(best_util, best_block_m, block_n);
            }
        }
        if (success) {
            best_block_m = block_m;
            best_block_n = block_n;
        }
    }
    RTP_LLM_CHECK_WITH_INFO(best_block_m != -1, "block m size cannot be None in best config");
    RTP_LLM_CHECK_WITH_INFO(best_block_n != -1, "block n size cannot be None in best config");

    int best_num_stages = -1, best_smem_size = -1;
    const int sm90_capacity = getMaxSmem();
    vector<int> num_stages_vec;
    if (128 % best_block_n) {
        num_stages_vec = vector<int>({6, 5, 4});
    } else {
        num_stages_vec = vector<int>({8, 7, 6, 5, 4});
    }
    for (auto& num_stages : num_stages_vec) {
        best_smem_size = getSmemSize(num_stages, k, best_block_m, best_block_n);
        if (best_smem_size <= sm90_capacity) {
            best_num_stages = num_stages;
            break;
        }
    }
    RTP_LLM_CHECK_WITH_INFO(best_num_stages != -1, "stages num cannot be None in best config");

    int best_num_tma_multicast = 1;
    if (m >= 1024 && isTmaMulticastLegal(n, best_block_n, 2, num_sms) && num_groups == 1) {
        best_num_tma_multicast = 2;
    }

    DeepGemmConfig value = DeepGemmConfig(best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size);
    best_configs.emplace(key, value);
    return value;
}
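
// The dispatch macros below expand to an exact-match check on (n, k, num_groups, gemm_type)
// and call the corresponding dispatchBlockNK instantiation, returning on the first hit.
// Only the shape combinations enumerated in runDeepGemm are supported; any other shape falls
// through to RTP_LLM_FAIL, so new model shapes need to be added here and, presumably, to the
// kernel instantiation list (the latter is an assumption about how the templates are built).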
#define DISPATCH_DEEP_GEMM_NORMAL(N, K, GROUP_NUM)                                              \
    if (n == N && k == K && num_groups == GROUP_NUM && gemm_type == DeepGemmType::Normal) {     \
        dispatchBlockNK<N, K, GROUP_NUM, DeepGemmType::Normal>(                                 \
            output, lhs, lhs_scale, rhs, rhs_scale, grouped_layout,                             \
            m, bm, bn, bk, num_stages, num_tma_multicast, stream, num_sms, smem_size);          \
        return;                                                                                 \
    }

#define DISPATCH_DEEP_GEMM_MOE(N, K, GROUP_NUM)                                                        \
    if (n == N && k == K && num_groups == GROUP_NUM && gemm_type == DeepGemmType::GroupedContiguous) { \
        dispatchBlockNK<N, K, GROUP_NUM, DeepGemmType::GroupedContiguous>(                             \
            output, lhs, lhs_scale, rhs, rhs_scale, grouped_layout,                                    \
            m, bm, bn, bk, num_stages, num_tma_multicast, stream, num_sms, smem_size);                 \
        return;                                                                                        \
    }                                                                                                  \
    if (n == N && k == K && num_groups == GROUP_NUM && gemm_type == DeepGemmType::GroupedMasked) {     \
        dispatchBlockNK<N, K, GROUP_NUM, DeepGemmType::GroupedMasked>(                                 \
            output, lhs, lhs_scale, rhs, rhs_scale, grouped_layout,                                    \
            m, bm, bn, bk, num_stages, num_tma_multicast, stream, num_sms, smem_size);                 \
        return;                                                                                        \
    }
void runDeepGemm(__nv_bfloat16* output,
                 __nv_fp8_e4m3* lhs,
                 float*         lhs_scale,
                 __nv_fp8_e4m3* rhs,
                 float*         rhs_scale,
                 int*           grouped_layout,
                 uint32_t       m,
                 uint32_t       n,
                 uint32_t       k,
                 uint32_t       bm,
                 uint32_t       bn,
                 uint32_t       bk,
                 uint32_t       num_groups,
                 uint32_t       num_stages,
                 uint32_t       num_tma_multicast,
                 DeepGemmType   gemm_type,
                 cudaStream_t   stream,
                 uint32_t       num_sms,
                 uint32_t       smem_size)
{
    RTP_LLM_LOG_DEBUG("m:%u, n:%u, k:%u, bm:%u, bn:%u, bk:%u, num_groups:%u, num_stages:%u, num_tma_multicast:%u\n",
                      m, n, k, bm, bn, bk, num_groups, num_stages, num_tma_multicast);

    /*
     * DeepSeek normal GEMM shapes
     */
    DISPATCH_DEEP_GEMM_NORMAL(2112, 7168, 1)
    DISPATCH_DEEP_GEMM_NORMAL(4096, 7168, 1)
    DISPATCH_DEEP_GEMM_NORMAL(7168, 2048, 1)
    DISPATCH_DEEP_GEMM_NORMAL(2048, 7168, 1)
    DISPATCH_DEEP_GEMM_NORMAL(16384, 512, 1)
    DISPATCH_DEEP_GEMM_NORMAL(24576, 1536, 1)
    DISPATCH_DEEP_GEMM_NORMAL(7168, 16384, 1)
    DISPATCH_DEEP_GEMM_NORMAL(18432, 7168, 1)
    DISPATCH_DEEP_GEMM_NORMAL(36864, 7168, 1)
    DISPATCH_DEEP_GEMM_NORMAL(7168, 18432, 1)
    // tp 8
    DISPATCH_DEEP_GEMM_NORMAL(3072, 1536, 1)
    DISPATCH_DEEP_GEMM_NORMAL(2048, 512, 1)
    DISPATCH_DEEP_GEMM_NORMAL(2304, 7168, 1)
    DISPATCH_DEEP_GEMM_NORMAL(7168, 2304, 1)
    // Grouped Contiguous
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 256)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 256)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 256)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 128)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 128)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 128)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 8)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 8)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 8)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 9)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 9)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 9)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 10)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 10)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 10)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 16)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 16)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 16)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 64)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 64)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 64)
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 32)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 32)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 32)
    // EP 128
    DISPATCH_DEEP_GEMM_MOE(4096, 7168, 2)
    DISPATCH_DEEP_GEMM_MOE(7168, 4096, 2)
    DISPATCH_DEEP_GEMM_MOE(7168, 2048, 2)
    /*
     * QWEN3
     */
    // tp1
    DISPATCH_DEEP_GEMM_NORMAL(9216, 4096, 1)
    DISPATCH_DEEP_GEMM_NORMAL(4096, 8192, 1)
    // tp2
    DISPATCH_DEEP_GEMM_NORMAL(4608, 4096, 1)
    DISPATCH_DEEP_GEMM_NORMAL(4096, 4096, 1)
    // tp4
    DISPATCH_DEEP_GEMM_NORMAL(2304, 4096, 1)
    DISPATCH_DEEP_GEMM_NORMAL(4096, 2048, 1)
    // moe ep1
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 128)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 128)
    // moe ep2
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 64)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 64)
    // moe ep4
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 32)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 32)
    // moe ep8
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 16)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 16)
    // moe ep 32
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 5)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 5)
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 4)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 4)
    // moe ep 64
    DISPATCH_DEEP_GEMM_MOE(3072, 4096, 2)
    DISPATCH_DEEP_GEMM_MOE(4096, 1536, 2)
    RTP_LLM_FAIL("DISPATCH_DEEP_GEMM(N=%u, K=%u, NUM_GROUPS=%u, GEMM_TYPE=%u) no template found",
                 n, k, num_groups, static_cast<uint32_t>(gemm_type));
}
#endif
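
// Plain FP8 GEMM: output[m, n] (bf16) = lhs[m, k] x rhs[n, k]^T, with lhs carrying per-token
// [m, k/128] fp32 scales and rhs carrying per-block [n/128, k/128] fp32 scales inside their
// QBuffer wrappers; n must be a multiple of 64 and k a multiple of 128. For a hypothetical
// m=128, k=7168, n=2112 call, getBestConfig picks the block sizes and stage count and
// runDeepGemm routes to the (N=2112, K=7168, GROUP_NUM=1) Normal instantiation above.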
void DeepGemmPlugin::gemmFp8(const Buffer &lhs, const Buffer &rhs, Buffer &output, cudaStream_t stream) {
#ifdef ENABLE_FP8
    // lhs.fp8 e4m3, [m, k]; scales -> fp32, [m, k / 128]
    // rhs.fp8 e4m3, [n, k]; scales -> fp32, [n / 128, k / 128]
    // output.bf16, [m, n]
    size_t m, n, k;
    m = lhs.shape()[0]; k = lhs.shape()[1]; n = rhs.size() / k;
    RTP_LLM_CHECK_WITH_INFO(n % 64 == 0 && k % 128 == 0, "n(%ld) % 64 or k(%ld) % 128 != 0", n, k);
    RTP_LLM_LOG_DEBUG("lhs:%s, scale:%s, rhs:%s, scale:%s out:%s",
                      lhs.debugString().c_str(),
                      reinterpret_cast<const QBuffer&>(lhs).scales().debugString().c_str(),
                      rhs.debugString().c_str(),
                      reinterpret_cast<const QBuffer&>(rhs).scales().debugString().c_str(),
                      output.debugString().c_str());

    int  num_sms     = getNumSms();
    auto best_config = getBestConfig(m, n, k, 1, num_sms);
    runDeepGemm(output.data<__nv_bfloat16>(),
                reinterpret_cast<const QBuffer&>(lhs).kernel().data<__nv_fp8_e4m3>(),
                reinterpret_cast<const QBuffer&>(lhs).scales().data<float>(),
                reinterpret_cast<const QBuffer&>(rhs).kernel().data<__nv_fp8_e4m3>(),
                reinterpret_cast<const QBuffer&>(rhs).scalesData<float>(),
                nullptr,  // grouped_layout
                m,
                n,
                k,
                best_config.block_m,
                best_config.block_n,
                128,  // block_k
                1,    // num_groups
                best_config.num_stages,
                best_config.num_tma_multicast,
                DeepGemmType::Normal,
                stream,
                num_sms,
                best_config.smem_size);
#endif
}
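
// Grouped GEMM over a contiguous token layout (MoE prefill style): rows of lhs belonging to
// different experts are concatenated along M, and m_indices[i] names the group that owns row i,
// so a single launch covers all experts. The per-token lhs scales are repacked into the
// TMA-aligned column-major layout before the launch.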
void DeepGemmPlugin::groupedGemmFp8Contiguous(const Buffer &lhs, const Buffer &rhs, Buffer &output, const Buffer &m_indices, cudaStream_t stream) {
#ifdef ENABLE_FP8
    // lhs.fp8 e4m3, [m_sum, k]; scales -> fp32, [m_sum, k / 128]
    // rhs.fp8 e4m3, [num_groups, n, k]; scales -> fp32, [num_groups, n / 128, k / 128]
    // output.bf16, [m_sum, n]
    // m_indices -> int32, [m_sum]
    size_t m, n, k;
    m = lhs.shape()[0]; k = lhs.shape()[1]; n = rhs.shape()[1];
    int num_groups = rhs.shape()[0];
    RTP_LLM_CHECK_WITH_INFO(n % 64 == 0 && k % 128 == 0, "n(%ld) % 64 or k(%ld) % 128 != 0", n, k);

    auto lhs_scales  = getColMajorTmaAlignedTensor(reinterpret_cast<const QBuffer&>(lhs).scales());
    int  num_sms     = getNumSms();
    auto best_config = getBestConfig(m, n, k, 1, num_sms, true);
    runDeepGemm(output.data<__nv_bfloat16>(),
                reinterpret_cast<const QBuffer&>(lhs).kernel().data<__nv_fp8_e4m3>(),
                (float*)lhs_scales.data_ptr(),
                reinterpret_cast<const QBuffer&>(rhs).kernel().data<__nv_fp8_e4m3>(),
                reinterpret_cast<const QBuffer&>(rhs).scalesData<float>(),
                m_indices.data<int>(),  // grouped_layout
                m,
                n,
                k,
                best_config.block_m,
                best_config.block_n,
                128,         // block_k
                num_groups,  // num_groups
                best_config.num_stages,
                best_config.num_tma_multicast,
                DeepGemmType::GroupedContiguous,
                stream,
                num_sms,
                best_config.smem_size);
#endif
}
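
// Grouped GEMM with a masked (padded) layout (MoE decode style): every expert gets a fixed
// m_max slot along dim 0 and masked_m[g] says how many of those rows are real. Note that
// expected_m is currently not used by this implementation (block sizes are derived from m_max),
// and the lhs scales are passed through as-is, so the caller is presumably expected to provide
// them already TMA-aligned (an assumption based on the absence of getColMajorTmaAlignedTensor
// in this path).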
void DeepGemmPlugin::groupedGemmFp8Masked(const Buffer &lhs, const Buffer &rhs, Buffer &output, const Buffer &masked_m, int expected_m, cudaStream_t stream) {
#ifdef ENABLE_FP8
    // lhs.fp8 e4m3, [num_groups, m_max, k]; scales -> fp32, [num_groups, m_max, k / 128]
    // rhs.fp8 e4m3, [num_groups, n, k]; scales -> fp32, [num_groups, n / 128, k / 128]
    // output.bf16, [m, n]
    // masked_m -> int32, [num_groups]
    size_t m, n, k;
    m = lhs.shape()[1]; k = lhs.shape()[2]; n = rhs.shape()[1];
    int num_groups = rhs.shape()[0];
    RTP_LLM_CHECK_WITH_INFO(n % 64 == 0 && k % 128 == 0, "n(%ld) % 64 or k(%ld) % 128 != 0", n, k);

    int  num_sms     = getNumSms();
    auto best_config = getBestConfig(m, n, k, num_groups, num_sms);
    runDeepGemm(output.data<__nv_bfloat16>(),
                reinterpret_cast<const QBuffer&>(lhs).kernel().data<__nv_fp8_e4m3>(),
                reinterpret_cast<const QBuffer&>(lhs).scalesData<float>(),
                reinterpret_cast<const QBuffer&>(rhs).kernel().data<__nv_fp8_e4m3>(),
                reinterpret_cast<const QBuffer&>(rhs).scalesData<float>(),
                masked_m.data<int>(),  // grouped_layout
                m,
                n,
                k,
                best_config.block_m,
                best_config.block_n,
                128,         // block_k
                num_groups,  // num_groups
                best_config.num_stages,
                best_config.num_tma_multicast,
                DeepGemmType::GroupedMasked,
                stream,
                num_sms,
                best_config.smem_size);
#endif
}
} // namespace rtp_llm