in maga_transformer/cpp/cuda/cublas/cublasFP8MMWrapper.cc [599:758]
void cublasFP8MMWrapper::Gemm_Bias_Act(__nv_bfloat16* res,
int batchCount,
int m,
int n,
int k,
int64_t strideA,
int64_t strideB,
int64_t strideD,
const float* alpha,
const float* beta,
const __nv_fp8_e4m3* input,
const __nv_fp8_e4m3* kernel,
const float* input_scale,
const float* kernel_scale,
const __nv_bfloat16* bias,
const float* output_scale,
cudaStream_t stream)
{
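    // Fused FP8 GEMM with a bias epilogue, optionally fused with ReLU/GELU:
    //   res = act(alpha * kernel_scale * input_scale * (kernel^T x input) + beta * res + bias)
    // Operands are FP8 (E4M3), accumulation is FP32, and the result is BF16.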
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    std::lock_guard<std::mutex> guard(*mu_);  // RAII: the mutex is released even if check_cuda_error throws
const void* devAscalePtr = (const void*)kernel_scale;
const void* devBscalePtr = (const void*)input_scale;
const void* devDscalePtr = (const void*)output_scale;
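    // Operand/scale pairing: the weight (kernel) is passed as matrix A and the
    // activation (input) as matrix B in the cublasLtMatmul call below, so the
    // A-scale is kernel_scale and the B-scale is input_scale. devDscalePtr is
    // currently unused: the output is BF16, so no D-scale attribute is set.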
const size_t wsSizeBytes = CUBLAS_WORKSPACE_SIZE;
const auto aType = CUDA_R_8F_E4M3;
const auto bType = CUDA_R_8F_E4M3;
const auto dType = CUDA_R_16BF;
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
// const auto epilogueAuxType = CUDA_R_16BF;
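    // cuBLASLt FP8 matmul requires the "TN" layout: A transposed, B non-transposed.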
const cublasOperation_t tA = CUBLAS_OP_T;
const cublasOperation_t tB = CUBLAS_OP_N;
//------- init, desc & tensors
cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t Adesc;
cublasLtMatrixLayout_t Bdesc;
cublasLtMatrixLayout_t Ddesc;
{
check_cuda_error(cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType));
check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
        // FAST_ACCUM requires cuBLASLt 11.11 or newer; compare (major, minor)
        // lexicographically so that 12.x releases are not rejected.
        if (version_major_ > 11 || (version_major_ == 11 && version_minor_ >= 11)) {
            const int8_t fastAccuMode = 1;  // enable fast, imprecise FP8 accumulation
            check_cuda_error(cublasLtMatmulDescSetAttribute(
                matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode)));
        }
        // TODO: Verify whether these scale attributes need to be set in all cases.
check_cuda_error(cublasLtMatmulDescSetAttribute(
matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &devAscalePtr, sizeof(devAscalePtr)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &devBscalePtr, sizeof(devBscalePtr)));
        // RELU / GELU are compile-time flags selecting the fused epilogue.
        cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
        if (RELU) {
            epi = CUBLASLT_EPILOGUE_RELU_BIAS;
        }
        else if (GELU) {
            epi = CUBLASLT_EPILOGUE_GELU_BIAS;
        }
        check_cuda_error(
            cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(epi)));
        check_cuda_error(
            cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*)));
}
{
const int64_t lda = k;
const int64_t ldb = k;
const int64_t ldd = n;
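        // Layouts are column-major: A is k x n with lda = k (so op(A) = A^T is
        // n x k), B is k x m with ldb = k, and D = op(A) * op(B) is n x m with
        // ldd = n. Viewed row-major, res is m x n: res = input[m,k] * kernel[n,k]^T.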
        // Create the matrix layout descriptors; extra attributes are only needed
        // for the strided-batched case handled below.
check_cuda_error(
cublasLtMatrixLayoutCreate(&Adesc, aType, tA == CUBLAS_OP_N ? n : k, tA == CUBLAS_OP_N ? k : n, lda));
if (batchCount > 1) {
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
}
check_cuda_error(
cublasLtMatrixLayoutCreate(&Bdesc, bType, tB == CUBLAS_OP_N ? k : m, tB == CUBLAS_OP_N ? m : k, ldb));
if (batchCount > 1) {
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
}
check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, dType, n, m, ldd));
if (batchCount > 1) {
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
check_cuda_error(cublasLtMatrixLayoutSetAttribute(
Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
}
}
const int requestedAlgoCount = 1;
cublasLtMatmulHeuristicResult_t heuristicResult;
cublasLtMatmulPreference_t preference;
int returnedAlgoCount = -1;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &wsSizeBytes, sizeof(wsSizeBytes)));
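    // Ask the heuristic for the single best algorithm that fits within the
    // workspace cap configured above.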
check_cuda_error(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle_,
matmulDesc,
Adesc,
Bdesc,
Ddesc,
Ddesc,
preference,
requestedAlgoCount,
&heuristicResult,
&returnedAlgoCount));
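    // heuristicResult.algo is only meaningful when at least one algorithm was
    // returned, so guard explicitly in case the call reports zero results.
    if (returnedAlgoCount <= 0) {
        throw std::runtime_error("no cublasLt algorithm found for FP8 Gemm_Bias_Act");
    }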
{
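        // Note the operand order: kernel is A, input is B, and res serves as both
        // C and D, so beta accumulates into the existing contents of res.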
cublasStatus_t status = cublasLtMatmul(cublaslt_handle_,
matmulDesc,
alpha,
kernel,
Adesc,
input,
Bdesc,
beta,
res,
Ddesc,
res,
Ddesc,
&heuristicResult.algo,
cublas_workspace_,
wsSizeBytes,
stream);
check_cuda_error(status);
}
    check_cuda_error(cublasLtMatmulPreferenceDestroy(preference));
    if (Ddesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Ddesc));
    }
if (Bdesc) {
check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
}
if (Adesc) {
check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
}
if (matmulDesc) {
check_cuda_error(cublasLtMatmulDescDestroy(matmulDesc));
}
}