// maga_transformer/cpp/cuda/cublas/cublasFP8MMWrapper.cc

/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cublasFP8MMWrapper.h"

namespace rtp_llm {

#define CUBLAS_WORKSPACE_1MB 1048576

cublasFP8MMWrapper::cublasFP8MMWrapper(cublasLtHandle_t cublaslt_handle,
                                       cudaStream_t stream,
                                       cublasAlgoMap* cublas_algo_map,
                                       std::mutex* mu,
                                       IAllocator* allocator):
    cublasMMWrapper(nullptr, cublaslt_handle, stream, cublas_algo_map, mu, allocator)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    RTP_LLM_CHECK_WITH_INFO(allocator != nullptr, "must pass allocator to cublasFP8MMWrapper");
    cublasVersionCheck();
    if (allocator_ != nullptr) {
        cublas_workspace_qgemm_ = allocator_->reMalloc(cublas_workspace_qgemm_, CUBLAS_WORKSPACE_1MB);
    }
}

cublasFP8MMWrapper::cublasFP8MMWrapper(cublasHandle_t cublas_handle,
                                       cublasLtHandle_t cublaslt_handle,
                                       cudaStream_t stream,
                                       cublasAlgoMap* cublas_algo_map,
                                       std::mutex* mu,
                                       IAllocator* allocator):
    cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, mu, allocator)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    RTP_LLM_CHECK_WITH_INFO(allocator != nullptr, "must pass allocator to cublasFP8MMWrapper");
    cublasVersionCheck();
    if (allocator_ != nullptr) {
        cublas_workspace_qgemm_ = allocator_->reMalloc(cublas_workspace_qgemm_, CUBLAS_WORKSPACE_1MB);
    }
}

cublasFP8MMWrapper::~cublasFP8MMWrapper()
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_ = nullptr;
    if (allocator_ != nullptr) {
        allocator_->free((void**)(&cublas_workspace_qgemm_));
    }
}

cublasFP8MMWrapper::cublasFP8MMWrapper(const cublasFP8MMWrapper& wrapper):
    cublasMMWrapper(wrapper.cublas_handle_,
                    wrapper.cublaslt_handle_,
                    wrapper.stream_,
                    wrapper.cublas_algo_map_,
                    wrapper.mu_,
                    wrapper.allocator_)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    cublasVersionCheck();
}

void cublasFP8MMWrapper::cublasVersionCheck()
{
    cublasGetProperty(MAJOR_VERSION, &version_major_);
    cublasGetProperty(MINOR_VERSION, &version_minor_);
    cublasGetProperty(PATCH_LEVEL, &version_patch_);
    size_t cublasVersion = (version_major_ * 10000 + version_minor_ * 100 + version_patch_);

#if defined(FP8_MHA) || !defined(FP8_GEMM_OUTPUT_QUANT_DISABLE)
    RTP_LLM_CHECK_WITH_INFO((version_major_ > 11)
                                || (version_major_ == 11 && version_minor_ == 11 && version_patch_ >= 4),
                            "FP8 MHA needs d-scale, which is only supported since cuBLAS 11.11.4!");
#endif
}

void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res,
                              int batchCount, int m, int n, int k,
                              int64_t strideA, int64_t strideB, int64_t strideD,
                              const float* alpha, const float* beta,
                              const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                              const float* input_scale, const float* kernel_scale)
{
    Gemm(res, batchCount, m, n, k, strideA, strideB, strideD, alpha, beta, input, kernel,
         input_scale, kernel_scale, (cudaStream_t)0);
}

void cublasFP8MMWrapper::Gemm(__nv_bfloat16* res,
                              int batchCount, int m, int n, int k,
                              int64_t strideA, int64_t strideB, int64_t strideD,
                              const float* alpha, const float* beta,
                              const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                              const float* input_scale, const float* kernel_scale,
                              cudaStream_t stream, bool fastAccum)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_->lock();

    const void* devAscalePtr = (const void*)kernel_scale;
    const void* devBscalePtr = (const void*)input_scale;

    const size_t wsSizeBytes = CUBLAS_WORKSPACE_SIZE;

    const auto aType       = CUDA_R_8F_E4M3;
    const auto bType       = CUDA_R_8F_E4M3;
    const auto dType       = CUDA_R_16BF;
    const auto computeType = CUBLAS_COMPUTE_32F;
    const auto scaleType   = CUDA_R_32F;
    // const auto epilogueAuxType = CUDA_R_16BF;

    const cublasOperation_t tA = CUBLAS_OP_T;
    const cublasOperation_t tB = CUBLAS_OP_N;

    //------- init, desc & tensors
    cublasLtMatmulDesc_t   matmulDesc;
    cublasLtMatrixLayout_t Adesc;
    cublasLtMatrixLayout_t Bdesc;
    cublasLtMatrixLayout_t Ddesc;
    {
        check_cuda_error(cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
        // Fast (imprecise) FP8 accumulation requires cuBLAS 11.11 or newer; the
        // check must also accept major versions above 11 (e.g. cuBLAS 12.x).
        if (((version_major_ > 11) || (version_major_ == 11 && version_minor_ >= 11 && version_patch_ > 0))
            && fastAccum) {
            const int8_t fastAccuMode = 1;  // enable fast imprecise accum
            check_cuda_error(cublasLtMatmulDescSetAttribute(
                matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(decltype(fastAccuMode))));
        }
        // TODO: check whether we actually need to set these scale attributes
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &devAscalePtr, sizeof(devAscalePtr)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &devBscalePtr, sizeof(devBscalePtr)));
    }
    {
        const int64_t lda = k;
        const int64_t ldb = k;
        const int64_t ldd = n;

        // create matrix descriptors; we are good with the defaults here, so no
        // extra attributes beyond the batch settings are needed
        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Adesc, aType, tA == CUBLAS_OP_N ? n : k, tA == CUBLAS_OP_N ? k : n, lda));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
        }

        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Bdesc, bType, tB == CUBLAS_OP_N ? k : m, tB == CUBLAS_OP_N ? m : k, ldb));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
        }

        check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, dType, n, m, ldd));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        }
    }

    bool                    findAlgo = cublas_algo_map_->isExist(batchCount, n, m, k, FP8_DATATYPE);
    cublasLtMatmulAlgo_info info     = cublas_algo_map_->getAlgo(batchCount, n, m, k, FP8_DATATYPE);
    if (info.stages == -1) {
        findAlgo = false;
    }

    cublasLtMatmulAlgo_t algo;
    int                  workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
    if (findAlgo) {
        if (info.workspaceSize > workspaceSize) {
            findAlgo = false;
        }
        else {
            cublasLtMatmulAlgoInit(
                cublaslt_handle_, computeType, scaleType, aType, bType, dType, dType, info.algoId, &algo);
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
            cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
#if (CUDART_VERSION >= 11000)
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, &(info.cluster_shapeId), sizeof(info.cluster_shapeId));
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
#endif
        }
    }

    {
        cublasStatus_t status = cublasLtMatmul(cublaslt_handle_,
                                               matmulDesc,
                                               alpha,
                                               kernel, Adesc,
                                               input, Bdesc,
                                               beta,
                                               nullptr, Ddesc,  // Cptr, not used here
                                               res, Ddesc,
                                               (findAlgo ? (&algo) : NULL),
                                               cublas_workspace_,
                                               wsSizeBytes,
                                               stream);
        check_cuda_error(status);
    }

    if (Ddesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Ddesc));
    }
    if (Bdesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
    }
    if (Adesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
    }
    if (matmulDesc) {
        check_cuda_error(cublasLtMatmulDescDestroy(matmulDesc));
    }
    mu_->unlock();
}
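// Illustrative usage sketch for the BF16-output FP8 GEMM above (disabled with
// #if 0, not part of the build). Buffer and scale names are hypothetical; the
// sketch assumes activations and weights were already quantized to e4m3 and
// that the per-tensor scales live in device memory, as the A/B scale-pointer
// attributes set above require.
#if 0
static void exampleFp8GemmBf16Out(cublasFP8MMWrapper& wrapper, cudaStream_t stream)
{
    const int   m = 16, n = 1024, k = 1024;
    const float alpha = 1.0f, beta = 0.0f;

    __nv_bfloat16* d_out          = nullptr;  // [m, n] result
    __nv_fp8_e4m3* d_act          = nullptr;  // [m, k] quantized activations ("input", B operand)
    __nv_fp8_e4m3* d_weight       = nullptr;  // [n, k] quantized weights ("kernel", A operand, CUBLAS_OP_T)
    float*         d_act_scale    = nullptr;  // device-side per-tensor dequant scale
    float*         d_weight_scale = nullptr;  // device-side per-tensor dequant scale
    cudaMalloc((void**)&d_out, sizeof(__nv_bfloat16) * m * n);
    cudaMalloc((void**)&d_act, sizeof(__nv_fp8_e4m3) * m * k);
    cudaMalloc((void**)&d_weight, sizeof(__nv_fp8_e4m3) * n * k);
    cudaMalloc((void**)&d_act_scale, sizeof(float));
    cudaMalloc((void**)&d_weight_scale, sizeof(float));
    // ... quantize into d_act / d_weight and write both scales ...

    // batchCount == 1, so the batch strides are unused.
    wrapper.Gemm(d_out, /*batchCount=*/1, m, n, k,
                 /*strideA=*/0, /*strideB=*/0, /*strideD=*/0,
                 &alpha, &beta, d_act, d_weight, d_act_scale, d_weight_scale, stream);
}
#endif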
void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res,
                              int batchCount, int m, int n, int k,
                              int64_t strideA, int64_t strideB, int64_t strideD,
                              const float* alpha, const float* beta,
                              const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                              const float* input_scale, const float* kernel_scale, const float* output_scale)
{
    Gemm(res, batchCount, m, n, k, strideA, strideB, strideD, alpha, beta, input, kernel,
         input_scale, kernel_scale, output_scale, 0);
}

void cublasFP8MMWrapper::Gemm(__nv_fp8_e4m3* res,
                              int batchCount, int m, int n, int k,
                              int64_t strideA, int64_t strideB, int64_t strideD,
                              const float* alpha, const float* beta,
                              const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                              const float* input_scale, const float* kernel_scale, const float* output_scale,
                              cudaStream_t stream, bool fastAccum)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_->lock();

    const void* devAscalePtr = (const void*)kernel_scale;
    const void* devBscalePtr = (const void*)input_scale;
    const void* devDscalePtr = (const void*)output_scale;

    RTP_LLM_CHECK(cublas_workspace_ != nullptr);
    const size_t wsSizeBytes = CUBLAS_WORKSPACE_SIZE;

    const auto aType       = CUDA_R_8F_E4M3;
    const auto bType       = CUDA_R_8F_E4M3;
    const auto cType       = CUDA_R_16BF;
    const auto dType       = CUDA_R_8F_E4M3;
    const auto computeType = CUBLAS_COMPUTE_32F;
    const auto scaleType   = CUDA_R_32F;

    const cublasOperation_t tA = CUBLAS_OP_T;
    const cublasOperation_t tB = CUBLAS_OP_N;

    //------- init, desc & tensors
    cublasLtMatmulDesc_t   matmulDesc;
    cublasLtMatrixLayout_t Adesc;
    cublasLtMatrixLayout_t Bdesc;
    cublasLtMatrixLayout_t Cdesc;
    cublasLtMatrixLayout_t Ddesc;
    {
        check_cuda_error(cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
        // Fast (imprecise) FP8 accumulation requires cuBLAS 11.11 or newer; the
        // check must also accept major versions above 11 (e.g. cuBLAS 12.x).
        if (((version_major_ > 11) || (version_major_ == 11 && version_minor_ >= 11 && version_patch_ > 0))
            && fastAccum) {
            const int8_t fastAccuMode = 1;  // enable fast imprecise accum
            check_cuda_error(cublasLtMatmulDescSetAttribute(
                matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(decltype(fastAccuMode))));
        }
        // TODO: check whether we actually need to set these scale attributes
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &devAscalePtr, sizeof(devAscalePtr)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &devBscalePtr, sizeof(devBscalePtr)));
        // check_cuda_error(cublasLtMatmulDescSetAttribute(
        //     matmulDesc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &devDscalePtr, sizeof(devDscalePtr)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &devDscalePtr, sizeof(devDscalePtr)));
    }
    {
        const int64_t lda = k;
        const int64_t ldb = k;
        const int64_t ldd = n;

        // create matrix descriptors; we are good with the defaults here, so no
        // extra attributes beyond the batch settings are needed
        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Adesc, aType, tA == CUBLAS_OP_N ? n : k, tA == CUBLAS_OP_N ? k : n, lda));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
        }

        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Bdesc, bType, tB == CUBLAS_OP_N ? k : m, tB == CUBLAS_OP_N ? m : k, ldb));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
        }

        check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, cType, n, m, ldd));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        }

        check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, dType, n, m, ldd));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        }
    }

    bool                    findAlgo = cublas_algo_map_->isExist(batchCount, n, m, k, FP8_DATATYPE);
    cublasLtMatmulAlgo_info info     = cublas_algo_map_->getAlgo(batchCount, n, m, k, FP8_DATATYPE);
    if (info.stages == -1) {
        findAlgo = false;
    }

    cublasLtMatmulAlgo_t algo;
    int                  workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
    if (findAlgo) {
        if (info.workspaceSize > workspaceSize) {
            findAlgo = false;
        }
        else {
            cublasLtMatmulAlgoInit(
                cublaslt_handle_, computeType, scaleType, aType, bType, cType, dType, info.algoId, &algo);
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption));
            cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme), sizeof(info.reductionScheme));
#if (CUDART_VERSION >= 11000)
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages));
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, &(info.cluster_shapeId), sizeof(info.cluster_shapeId));
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId));
            cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode));
#endif
        }
    }

    {
        cublasStatus_t status = cublasLtMatmul(cublaslt_handle_,
                                               matmulDesc,
                                               alpha,
                                               kernel, Adesc,
                                               input, Bdesc,
                                               beta,
                                               nullptr, Cdesc,  // Cptr, not used here
                                               res, Ddesc,
                                               (findAlgo ? (&algo) : NULL),
                                               cublas_workspace_,
                                               wsSizeBytes,
                                               stream);
        check_cuda_error(status);
    }

    if (Ddesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Ddesc));
    }
    if (Cdesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Cdesc));
    }
    if (Bdesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
    }
    if (Adesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
    }
    if (matmulDesc) {
        check_cuda_error(cublasLtMatmulDescDestroy(matmulDesc));
    }
    mu_->unlock();
}
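// Illustrative usage sketch for the FP8-output GEMM above (disabled with #if 0).
// It differs from the BF16-output path only in the extra device-side d-scale,
// which cublasLt uses to quantize the FP32 accumulator back to e4m3; this is
// the capability gated on cuBLAS >= 11.11.4 in cublasVersionCheck(). All names
// below are hypothetical.
#if 0
static void exampleFp8GemmFp8Out(cublasFP8MMWrapper& wrapper, cudaStream_t stream)
{
    const int   m = 16, n = 1024, k = 1024;
    const float alpha = 1.0f, beta = 0.0f;

    __nv_fp8_e4m3* d_out    = nullptr;  // [m, n] quantized result
    __nv_fp8_e4m3* d_act    = nullptr;  // [m, k] quantized activations
    __nv_fp8_e4m3* d_weight = nullptr;  // [n, k] quantized weights
    float*         d_scales = nullptr;  // [0]=act scale, [1]=weight scale, [2]=output scale
    cudaMalloc((void**)&d_out, sizeof(__nv_fp8_e4m3) * m * n);
    cudaMalloc((void**)&d_act, sizeof(__nv_fp8_e4m3) * m * k);
    cudaMalloc((void**)&d_weight, sizeof(__nv_fp8_e4m3) * n * k);
    cudaMalloc((void**)&d_scales, sizeof(float) * 3);
    // ... quantize the inputs and fill all three scales ...

    wrapper.Gemm(d_out, /*batchCount=*/1, m, n, k,
                 /*strideA=*/0, /*strideB=*/0, /*strideD=*/0,
                 &alpha, &beta, d_act, d_weight,
                 &d_scales[0], &d_scales[1], &d_scales[2], stream);
}
#endif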
template<bool RELU, bool GELU>
void cublasFP8MMWrapper::Conv1x1Gemm(__nv_fp8_e4m3* res,
                                     int m, int n, int k,
                                     const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                     const __nv_bfloat16* bias,
                                     const float input_scale, const float kernel_scale, const float output_scale,
                                     cudaStream_t stream)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_->lock();
    size_t workspace_size = 0;
    // get workspace size
    qgmmaLauncher.getWorkSpaceSize<RELU, GELU>(n, workspace_size);
    if (workspace_size > CUBLAS_WORKSPACE_1MB) {
        throw std::runtime_error("Need to reallocate workspace for qgemm. It is not supported");
        // cublas_workspace_qgemm_ = allocator_->reMalloc(cublas_workspace_qgemm_, workspace_size);
    }
    qgmmaLauncher.invokeQgmma1x1<RELU, GELU>(
        res, m, n, k, input, kernel, bias, input_scale, kernel_scale, output_scale, cublas_workspace_qgemm_, stream);
    sync_check_cuda_error();
    mu_->unlock();
}

template void cublasFP8MMWrapper::Conv1x1Gemm<true, false>(__nv_fp8_e4m3* res, int m, int n, int k,
                                                           const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                           const __nv_bfloat16* bias, const float input_scale,
                                                           const float kernel_scale, const float output_scale,
                                                           cudaStream_t stream);
template void cublasFP8MMWrapper::Conv1x1Gemm<true, true>(__nv_fp8_e4m3* res, int m, int n, int k,
                                                          const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                          const __nv_bfloat16* bias, const float input_scale,
                                                          const float kernel_scale, const float output_scale,
                                                          cudaStream_t stream);
template void cublasFP8MMWrapper::Conv1x1Gemm<false, false>(__nv_fp8_e4m3* res, int m, int n, int k,
                                                            const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                            const __nv_bfloat16* bias, const float input_scale,
                                                            const float kernel_scale, const float output_scale,
                                                            cudaStream_t stream);
template void cublasFP8MMWrapper::Conv1x1Gemm<false, true>(__nv_fp8_e4m3* res, int m, int n, int k,
                                                           const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                           const __nv_bfloat16* bias, const float input_scale,
                                                           const float kernel_scale, const float output_scale,
                                                           cudaStream_t stream);
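// Illustrative usage sketch for Conv1x1Gemm (disabled with #if 0). The two
// template flags select the fused activation (<RELU, GELU>), and, unlike the
// cublasLt paths above, the scales here are plain host floats passed by value.
// Only the four <RELU, GELU> combinations explicitly instantiated above are
// usable from other translation units. Names and scale values below are
// hypothetical.
#if 0
static void exampleConv1x1GemmGelu(cublasFP8MMWrapper& wrapper, cudaStream_t stream)
{
    const int m = 16, n = 1024, k = 1024;

    __nv_fp8_e4m3* d_out    = nullptr;  // [m, n] quantized result
    __nv_fp8_e4m3* d_act    = nullptr;  // [m, k] quantized activations
    __nv_fp8_e4m3* d_weight = nullptr;  // [n, k] quantized weights
    __nv_bfloat16* d_bias   = nullptr;  // [n] bias, added before the activation
    cudaMalloc((void**)&d_out, sizeof(__nv_fp8_e4m3) * m * n);
    cudaMalloc((void**)&d_act, sizeof(__nv_fp8_e4m3) * m * k);
    cudaMalloc((void**)&d_weight, sizeof(__nv_fp8_e4m3) * n * k);
    cudaMalloc((void**)&d_bias, sizeof(__nv_bfloat16) * n);
    // ... quantize the inputs and fill the bias ...

    // <RELU=false, GELU=true>: fused bias + GELU epilogue.
    wrapper.Conv1x1Gemm<false, true>(d_out, m, n, k, d_act, d_weight, d_bias,
                                     /*input_scale=*/0.05f, /*kernel_scale=*/0.02f,
                                     /*output_scale=*/4.0f, stream);
}
#endif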
template<bool RELU, bool GELU>
void cublasFP8MMWrapper::Gemm_Bias_Act(__nv_bfloat16* res,
                                       int batchCount, int m, int n, int k,
                                       int64_t strideA, int64_t strideB, int64_t strideD,
                                       const float* alpha, const float* beta,
                                       const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                       const float* input_scale, const float* kernel_scale,
                                       const __nv_bfloat16* bias, const float* output_scale,
                                       cudaStream_t stream)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_->lock();

    const void* devAscalePtr = (const void*)kernel_scale;
    const void* devBscalePtr = (const void*)input_scale;
    const void* devDscalePtr = (const void*)output_scale;  // unused for a BF16 result (no d-scale)

    const size_t wsSizeBytes = CUBLAS_WORKSPACE_SIZE;

    const auto aType       = CUDA_R_8F_E4M3;
    const auto bType       = CUDA_R_8F_E4M3;
    const auto dType       = CUDA_R_16BF;
    const auto computeType = CUBLAS_COMPUTE_32F;
    const auto scaleType   = CUDA_R_32F;
    // const auto epilogueAuxType = CUDA_R_16BF;

    const cublasOperation_t tA = CUBLAS_OP_T;
    const cublasOperation_t tB = CUBLAS_OP_N;

    //------- init, desc & tensors
    cublasLtMatmulDesc_t   matmulDesc;
    cublasLtMatrixLayout_t Adesc;
    cublasLtMatrixLayout_t Bdesc;
    cublasLtMatrixLayout_t Ddesc;
    {
        check_cuda_error(cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
        // Fast (imprecise) FP8 accumulation requires cuBLAS 11.11 or newer; the
        // check must also accept major versions above 11 (e.g. cuBLAS 12.x).
        if ((version_major_ > 11) || (version_major_ == 11 && version_minor_ >= 11 && version_patch_ > 0)) {
            const int8_t fastAccuMode = 1;  // enable fast imprecise accum
            check_cuda_error(cublasLtMatmulDescSetAttribute(
                matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(decltype(fastAccuMode))));
        }
        // TODO: check whether we actually need to set these scale attributes
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &devAscalePtr, sizeof(devAscalePtr)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &devBscalePtr, sizeof(devBscalePtr)));

        // Select the fused epilogue from the template flags.
        cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
        if (RELU) {
            epi = CUBLASLT_EPILOGUE_RELU_BIAS;
        }
        else if (GELU) {
            epi = CUBLASLT_EPILOGUE_GELU_BIAS;
        }
        cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
        cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
    }
    {
        const int64_t lda = k;
        const int64_t ldb = k;
        const int64_t ldd = n;

        // create matrix descriptors; we are good with the defaults here, so no
        // extra attributes beyond the batch settings are needed
        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Adesc, aType, tA == CUBLAS_OP_N ? n : k, tA == CUBLAS_OP_N ? k : n, lda));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
        }

        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Bdesc, bType, tB == CUBLAS_OP_N ? k : m, tB == CUBLAS_OP_N ? m : k, ldb));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
        }

        check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, dType, n, m, ldd));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        }
    }

    const int                       requestedAlgoCount = 1;
    cublasLtMatmulHeuristicResult_t heuristicResult;
    cublasLtMatmulPreference_t      preference;
    int                             returnedAlgoCount = -1;
    check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &wsSizeBytes, sizeof(wsSizeBytes)));
    check_cuda_error(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle_, matmulDesc, Adesc, Bdesc, Ddesc, Ddesc,
                                                    preference, requestedAlgoCount, &heuristicResult,
                                                    &returnedAlgoCount));
    // The preference handle is no longer needed once the heuristic has been queried.
    check_cuda_error(cublasLtMatmulPreferenceDestroy(preference));

    {
        cublasStatus_t status = cublasLtMatmul(cublaslt_handle_,
                                               matmulDesc,
                                               alpha,
                                               kernel, Adesc,
                                               input, Bdesc,
                                               beta,
                                               res, Ddesc,
                                               res, Ddesc,
                                               &heuristicResult.algo,
                                               cublas_workspace_,
                                               wsSizeBytes,
                                               stream);
        check_cuda_error(status);
    }

    if (Ddesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Ddesc));
    }
    if (Bdesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
    }
    if (Adesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
    }
    if (matmulDesc) {
        check_cuda_error(cublasLtMatmulDescDestroy(matmulDesc));
    }
    mu_->unlock();
}
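// Illustrative usage sketch for the BF16-output Gemm_Bias_Act above (disabled
// with #if 0). The <RELU, GELU> flags pick the cublasLt epilogue (RELU_BIAS /
// GELU_BIAS / plain BIAS); output_scale is accepted for signature symmetry
// with the FP8-output variant below but unused for a BF16 result. Names are
// hypothetical.
#if 0
static void exampleGemmBiasGeluBf16Out(cublasFP8MMWrapper& wrapper, cudaStream_t stream)
{
    const int   m = 16, n = 1024, k = 1024;
    const float alpha = 1.0f, beta = 0.0f;

    __nv_bfloat16* d_out    = nullptr;  // [m, n] result
    __nv_fp8_e4m3* d_act    = nullptr;  // [m, k] quantized activations
    __nv_fp8_e4m3* d_weight = nullptr;  // [n, k] quantized weights
    __nv_bfloat16* d_bias   = nullptr;  // [n] bias for the epilogue
    float*         d_scales = nullptr;  // [0]=act scale, [1]=weight scale
    // ... cudaMalloc and fill the buffers above ...

    wrapper.Gemm_Bias_Act<false, true>(d_out, /*batchCount=*/1, m, n, k,
                                       /*strideA=*/0, /*strideB=*/0, /*strideD=*/0,
                                       &alpha, &beta, d_act, d_weight,
                                       &d_scales[0], &d_scales[1], d_bias,
                                       /*output_scale=*/nullptr, stream);
}
#endif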
template<bool RELU, bool GELU>
void cublasFP8MMWrapper::Gemm_Bias_Act(__nv_fp8_e4m3* res,
                                       int batchCount, int m, int n, int k,
                                       int64_t strideA, int64_t strideB, int64_t strideD,
                                       const float* alpha, const float* beta,
                                       const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                       const float* input_scale, const float* kernel_scale,
                                       const __nv_bfloat16* bias, const float* output_scale,
                                       cudaStream_t stream)
{
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    mu_->lock();

    const void* devAscalePtr = (const void*)kernel_scale;
    const void* devBscalePtr = (const void*)input_scale;
    const void* devDscalePtr = (const void*)output_scale;

    const size_t wsSizeBytes = CUBLAS_WORKSPACE_SIZE;

    const auto aType       = CUDA_R_8F_E4M3;
    const auto bType       = CUDA_R_8F_E4M3;
    const auto cType       = CUDA_R_16BF;
    const auto dType       = CUDA_R_8F_E4M3;
    const auto computeType = CUBLAS_COMPUTE_32F;
    const auto scaleType   = CUDA_R_32F;
    // const auto epilogueAuxType = CUDA_R_16BF;

    const cublasOperation_t tA = CUBLAS_OP_T;
    const cublasOperation_t tB = CUBLAS_OP_N;

    //------- init, desc & tensors
    cublasLtMatmulDesc_t   matmulDesc;
    cublasLtMatrixLayout_t Adesc;
    cublasLtMatrixLayout_t Bdesc;
    cublasLtMatrixLayout_t Cdesc;
    cublasLtMatrixLayout_t Ddesc;
    {
        check_cuda_error(cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &tA, sizeof(tA)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &tB, sizeof(tB)));
        // Fast (imprecise) FP8 accumulation requires cuBLAS 11.11 or newer; the
        // check must also accept major versions above 11 (e.g. cuBLAS 12.x).
        if ((version_major_ > 11) || (version_major_ == 11 && version_minor_ >= 11 && version_patch_ > 0)) {
            const int8_t fastAccuMode = 1;  // enable fast imprecise accum
            check_cuda_error(cublasLtMatmulDescSetAttribute(
                matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(decltype(fastAccuMode))));
        }
        // TODO: check whether we actually need to set these scale attributes
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &devAscalePtr, sizeof(devAscalePtr)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &devBscalePtr, sizeof(devBscalePtr)));
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &devDscalePtr, sizeof(devDscalePtr)));

        // Select the fused epilogue from the template flags.
        cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
        if (RELU) {
            epi = CUBLASLT_EPILOGUE_RELU_BIAS;
        }
        else if (GELU) {
            epi = CUBLASLT_EPILOGUE_GELU_BIAS;
        }
        cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
        cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*));
    }
    {
        const int64_t lda = k;
        const int64_t ldb = k;
        const int64_t ldd = n;

        // create matrix descriptors; we are good with the defaults here, so no
        // extra attributes beyond the batch settings are needed
        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Adesc, aType, tA == CUBLAS_OP_N ? n : k, tA == CUBLAS_OP_N ? k : n, lda));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
        }

        check_cuda_error(
            cublasLtMatrixLayoutCreate(&Bdesc, bType, tB == CUBLAS_OP_N ? k : m, tB == CUBLAS_OP_N ? m : k, ldb));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
        }

        check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, cType, n, m, ldd));
        // (TODO Hongbinl) Not sure if the implementation makes sense
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        }

        check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, dType, n, m, ldd));
        if (batchCount > 1) {
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
            check_cuda_error(cublasLtMatrixLayoutSetAttribute(
                Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
        }
    }

    const int                       requestedAlgoCount = 1;
    cublasLtMatmulHeuristicResult_t heuristicResult;
    cublasLtMatmulPreference_t      preference;
    int                             returnedAlgoCount = -1;
    check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &wsSizeBytes, sizeof(wsSizeBytes)));
#if (CUBLAS_VERSION) <= 12000
    uint32_t pointer_mode_mask = 0;
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif
    check_cuda_error(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle_, matmulDesc, Adesc, Bdesc, Cdesc, Ddesc,
                                                    preference, requestedAlgoCount, &heuristicResult,
                                                    &returnedAlgoCount));
    // The preference handle is no longer needed once the heuristic has been queried.
    check_cuda_error(cublasLtMatmulPreferenceDestroy(preference));

    {
        cublasStatus_t status = cublasLtMatmul(cublaslt_handle_,
                                               matmulDesc,
                                               alpha,
                                               kernel, Adesc,
                                               input, Bdesc,
                                               beta,
                                               res, Cdesc,
                                               res, Ddesc,
                                               &heuristicResult.algo,
                                               cublas_workspace_,
                                               wsSizeBytes,
                                               stream);
        check_cuda_error(status);
    }

    if (Ddesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Ddesc));
    }
    if (Cdesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Cdesc));
    }
    if (Bdesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Bdesc));
    }
    if (Adesc) {
        check_cuda_error(cublasLtMatrixLayoutDestroy(Adesc));
    }
    if (matmulDesc) {
        check_cuda_error(cublasLtMatmulDescDestroy(matmulDesc));
    }
    mu_->unlock();
}
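// Usage of the FP8-output variant mirrors the sketch after the BF16-output
// overload, with an __nv_fp8_e4m3* result and a device-side output_scale that
// feeds CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, e.g. (hypothetical names, not
// built):
#if 0
    wrapper.Gemm_Bias_Act<false, true>(d_out_fp8, /*batchCount=*/1, m, n, k, 0, 0, 0,
                                       &alpha, &beta, d_act, d_weight,
                                       &d_scales[0], &d_scales[1], d_bias,
                                       /*output_scale=*/&d_scales[2], stream);
#endif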
template void cublasFP8MMWrapper::Gemm_Bias_Act<false, true>(__nv_bfloat16* res,
                                                             int batchCount, int m, int n, int k,
                                                             int64_t strideA, int64_t strideB, int64_t strideD,
                                                             const float* alpha, const float* beta,
                                                             const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                             const float* input_scale, const float* kernel_scale,
                                                             const __nv_bfloat16* bias, const float* output_scale,
                                                             cudaStream_t stream);
template void cublasFP8MMWrapper::Gemm_Bias_Act<false, true>(__nv_fp8_e4m3* res,
                                                             int batchCount, int m, int n, int k,
                                                             int64_t strideA, int64_t strideB, int64_t strideD,
                                                             const float* alpha, const float* beta,
                                                             const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                             const float* input_scale, const float* kernel_scale,
                                                             const __nv_bfloat16* bias, const float* output_scale,
                                                             cudaStream_t stream);
template void cublasFP8MMWrapper::Gemm_Bias_Act<true, false>(__nv_bfloat16* res,
                                                             int batchCount, int m, int n, int k,
                                                             int64_t strideA, int64_t strideB, int64_t strideD,
                                                             const float* alpha, const float* beta,
                                                             const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                             const float* input_scale, const float* kernel_scale,
                                                             const __nv_bfloat16* bias, const float* output_scale,
                                                             cudaStream_t stream);
template void cublasFP8MMWrapper::Gemm_Bias_Act<true, false>(__nv_fp8_e4m3* res,
                                                             int batchCount, int m, int n, int k,
                                                             int64_t strideA, int64_t strideB, int64_t strideD,
                                                             const float* alpha, const float* beta,
                                                             const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                             const float* input_scale, const float* kernel_scale,
                                                             const __nv_bfloat16* bias, const float* output_scale,
                                                             cudaStream_t stream);
template void cublasFP8MMWrapper::Gemm_Bias_Act<false, false>(__nv_fp8_e4m3* res,
                                                              int batchCount, int m, int n, int k,
                                                              int64_t strideA, int64_t strideB, int64_t strideD,
                                                              const float* alpha, const float* beta,
                                                              const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* kernel,
                                                              const float* input_scale, const float* kernel_scale,
                                                              const __nv_bfloat16* bias, const float* output_scale,
                                                              cudaStream_t stream);

}  // namespace rtp_llm