// maga_transformer/cpp/cuda/cublas/cublasMMWrapper.cc
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cublasMMWrapper.h"
#include "maga_transformer/cpp/utils/ScopeGuard.h"
#include <algorithm>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
namespace rtp_llm {
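// cublasMMWrapper routes GEMM calls to either plain cublas or cublasLt, keeps a
// per-stream workspace pool, and consults a pre-tuned cublasAlgoMap before
// falling back to heuristic or default algorithm selection. All public GEMM
// entry points serialize on the shared mutex passed in at construction.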
cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle,
cublasLtHandle_t cublaslt_handle,
cudaStream_t stream,
cublasAlgoMap* cublas_algo_map,
std::mutex* mu,
IAllocator* allocator):
cublas_handle_(cublas_handle),
cublaslt_handle_(cublaslt_handle),
stream_(stream),
cublas_algo_map_(cublas_algo_map),
mutex_(mu),
allocator_(allocator) {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (allocator_ != nullptr) {
cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE);
}
}
cublasMMWrapper::~cublasMMWrapper() {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
mutex_ = nullptr;
if (allocator_ != nullptr) {
allocator_->free((void**)(&cublas_workspace_));
for (size_t i = 0; i < additional_cublas_workspaces_.size(); i++) {
allocator_->free((void**)(&additional_cublas_workspaces_[i]));
}
allocator_ = nullptr;
}
}
cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper& wrapper):
cublas_handle_(wrapper.cublas_handle_),
cublaslt_handle_(wrapper.cublaslt_handle_),
stream_(wrapper.stream_),
cublas_algo_map_(wrapper.cublas_algo_map_),
mutex_(wrapper.mutex_),
allocator_(wrapper.allocator_) {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (allocator_ != nullptr) {
cublas_workspace_ = allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE);
}
}
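// Illustrative usage sketch (assumptions: an IAllocator implementation and a
// populated cublasAlgoMap are created elsewhere; handle/stream setup is the
// caller's responsibility):
//
//   cublasHandle_t   handle;     cublasCreate(&handle);
//   cublasLtHandle_t lt_handle;  cublasLtCreate(&lt_handle);
//   cudaStream_t     stream;     cudaStreamCreate(&stream);
//   std::mutex       mu;
//   cublasMMWrapper wrapper(handle, lt_handle, stream, &algo_map, &mu, allocator);
//   wrapper.setFP16GemmConfig();
//   wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
//                A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /*math_sm_count=*/0, stream);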
void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasOperation_t transb,
const int m,
const int n,
const int k,
const void* A,
const int lda,
const void* B,
const int ldb,
void* C,
const int ldc,
float f_alpha,
float f_beta,
int math_sm_count,
cudaStream_t stream) {
    Gemm(transa, transb, m, n, k,
         A, Atype_, lda, B, Btype_, ldb, C, Ctype_, ldc,
         computeType_, f_alpha, f_beta,
         /*A_scale=*/nullptr, /*B_scale=*/nullptr,
         math_sm_count, /*fast_accum=*/0, stream);
}
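// Low-level cublasLt path: builds matrix layouts and a matmul descriptor, wires
// optional FP8 per-tensor scales and fast accumulation, selects a per-stream
// workspace, and replays a pre-tuned algorithm configuration when one fits.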
void cublasMMWrapper::cublasLtGemm(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void* alpha, /* host or device pointer */
const void* A,
const void* A_scale,
cudaDataType Atype,
int lda,
const void* B,
const void* B_scale,
cudaDataType Btype,
int ldb,
const void* beta, /* host or device pointer */
void* C,
cudaDataType Ctype,
int ldc,
bool is_fp16_computeType,
cublasLtMatmulAlgo_info info,
bool findAlgo,
int math_sm_count,
int8_t fast_accum,
cudaStream_t stream) {
cublasLtMatrixLayout_t Adesc;
cublasLtMatrixLayout_t Bdesc;
cublasLtMatrixLayout_t Cdesc;
cublasLtMatrixLayout_t Ddesc;
cublasLtMatmulDesc_t operationDesc;
cudaDataType_t scaleType;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t computeType;
#else
cudaDataType_t computeType;
#endif
if (is_fp16_computeType) {
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_16F;
#else
computeType = CUDA_R_16F;
#endif
scaleType = CUDA_R_16F;
} else {
#if (CUDART_VERSION >= 11000)
computeType = CUBLAS_COMPUTE_32F;
#else
computeType = CUDA_R_32F;
#endif
scaleType = CUDA_R_32F;
}
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, Atype,
transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Adesc); });
check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, Btype,
transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Bdesc); });
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Cdesc); });
    // NOTE: Ddesc is created with Btype, but the matmul call below passes Cdesc
    // for both C and D, so this layout is effectively unused.
    check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, Btype, m, n, ldc));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Ddesc); });
#if (CUDART_VERSION >= 11000)
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
#else
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
#endif
FT_SCOPE_GUARD([&](){ cublasLtMatmulDescDestroy(operationDesc); });
if (math_sm_count > 0) {
check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count)));
}
check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fast_accum, sizeof(int8_t)));
if (A_scale != nullptr) {
check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &A_scale, sizeof(void*)));
}
if (B_scale != nullptr) {
check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale, sizeof(void*)));
}
cublasLtMatmulAlgo_t algo;
void* workSpace = cublas_workspace_;
uint64_t workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
    if (stream != stream_) {
        // GEMMs on a foreign stream get their own lazily allocated workspace so
        // that concurrent streams never share scratch memory.
        if (cublas_workspces_map_.count(stream) == 0) {
            void* additional_cublas_workspace = nullptr;
            additional_cublas_workspace = allocator_->reMalloc(additional_cublas_workspace, CUBLAS_WORKSPACE_SIZE);
            additional_cublas_workspaces_.push_back(additional_cublas_workspace);
            cublas_workspces_map_[stream] = additional_cublas_workspaces_.size() - 1;
        }
        workSpace = additional_cublas_workspaces_[cublas_workspces_map_[stream]];
        // The per-stream workspace was just allocated, so its size is always
        // CUBLAS_WORKSPACE_SIZE.
        workspaceSize = CUBLAS_WORKSPACE_SIZE;
        RTP_LLM_LOG_DEBUG("stream %p, idx %d", (void*)stream, (int)cublas_workspces_map_[stream]);
    }
    if (findAlgo) {
        // Skip the tuned config if it needs more workspace than is available.
        if (info.workspaceSize > workspaceSize) {
            findAlgo = false;
        } else {
check_cuda_error(cublasLtMatmulAlgoInit(
cublaslt_handle_, computeType, scaleType, Atype, Btype, Ctype, Ctype, info.algoId, &algo));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(info.reductionScheme),
sizeof(info.reductionScheme)));
#if (CUDART_VERSION >= 11000)
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)));
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
&(info.cluster_shapeId),
sizeof(info.cluster_shapeId)));
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode)));
#endif
}
}
check_cuda_error(cublasLtMatmulWrapper(cublaslt_handle_,
operationDesc,
alpha,
A,
Adesc,
B,
Bdesc,
beta,
C,
Cdesc,
C,
Cdesc,
                                           (findAlgo ? &algo : NULL),
workSpace,
workspaceSize,
stream,
/* find_best = */false));
sync_check_cuda_error();
}
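// Main dispatch: FP16/BF16/FP8 inputs prefer the cublasLt path; a tuned algo-map
// entry can override that choice (stages == -1 marks a plain-cublas record).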
void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasOperation_t transb,
const int m,
const int n,
const int k,
const void* A,
cudaDataType_t Atype,
const int lda,
const void* B,
cudaDataType_t Btype,
const int ldb,
void* C,
cudaDataType_t Ctype,
const int ldc,
cudaDataType_t computeType,
float f_alpha,
float f_beta,
const float* A_scale,
const float* B_scale,
int math_sm_count,
int8_t fast_accum,
cudaStream_t stream) {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(*mutex_);
half h_alpha = (half)(f_alpha);
half h_beta = (half)(f_beta);
// TODO: default cublas libs
    const bool is_fp16_computeType = (computeType == CUDA_R_16F);
    bool using_cublasLt = (Atype == CUDA_R_16F || Atype == CUDA_R_8F_E4M3 || Atype == CUDA_R_16BF);
int batch_count = 1;
// fp32 use cublas as default
// fp16 use cublasLt as default
    const void* alpha = is_fp16_computeType ? static_cast<const void*>(&h_alpha) : static_cast<const void*>(&f_alpha);
    const void* beta  = is_fp16_computeType ? static_cast<const void*>(&h_beta) : static_cast<const void*>(&f_beta);
    bool findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, getCublasDataType(Atype));
    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype));
    if (findAlgo) {
        RTP_LLM_LOG_DEBUG("Using pre-tuned cublasLt algorithm");
        // stages == -1 marks an entry without cublasLt stage info, so fall back
        // to plain cublas for that record.
        using_cublasLt = (info.stages != -1);
    } else {
        RTP_LLM_LOG_DEBUG("Falling back to default cublas algorithm");
    }
RTP_LLM_LOG_DEBUG("using cublasLt: %d", using_cublasLt);
try {
if (using_cublasLt) {
const void* A_scale_ptr = static_cast<const void*>(A_scale);
const void* B_scale_ptr = static_cast<const void*>(B_scale);
cublasLtGemm(cublas_handle_,
transa,
transb,
m,
n,
k,
alpha,
A,
A_scale_ptr,
Atype,
lda,
B,
B_scale_ptr,
Btype,
ldb,
beta,
C,
Ctype,
ldc,
is_fp16_computeType,
info,
findAlgo,
math_sm_count,
fast_accum,
stream);
} else {
int cublasAlgo = info.algoId;
check_cuda_error(cublasGemmEx(cublas_handle_,
transa,
transb,
m,
n,
k,
alpha,
A,
Atype,
lda,
B,
Btype,
ldb,
beta,
C,
Ctype,
ldc,
computeType,
static_cast<cublasGemmAlgo_t>(cublasAlgo)));
}
sync_check_cuda_error();
} catch (const std::exception& e) {
RTP_LLM_LOG_ERROR("cublasMMWrapper::Gemm exception %s", e.what());
throw;
}
}
void cublasMMWrapper::setFP32GemmConfig() {
Atype_ = CUDA_R_32F;
Btype_ = CUDA_R_32F;
Ctype_ = CUDA_R_32F;
computeType_ = CUDA_R_32F;
}
void cublasMMWrapper::setFP16GemmConfig() {
Atype_ = CUDA_R_16F;
Btype_ = CUDA_R_16F;
Ctype_ = CUDA_R_16F;
computeType_ = CUDA_R_32F;
}
void cublasMMWrapper::setBF16GemmConfig() {
Atype_ = CUDA_R_16BF;
Btype_ = CUDA_R_16BF;
Ctype_ = CUDA_R_16BF;
computeType_ = CUDA_R_32F;
}
#ifdef ENABLE_FP8
void cublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType) {
    setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
}
#endif
void cublasMMWrapper::setGemmConfig(cudaDataType_t aType,
cudaDataType_t bType,
cudaDataType_t cType,
cudaDataType_t computeType) {
Atype_ = aType;
Btype_ = bType;
Ctype_ = cType;
computeType_ = computeType;
}
CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type) {
if (data_type == CUDA_R_16F) {
return HALF_DATATYPE;
} else if (data_type == CUDA_R_32F) {
return FLOAT_DATATYPE;
}
#ifdef ENABLE_BF16
else if (data_type == CUDA_R_16BF) {
return BFLOAT16_DATATYPE;
}
#endif
    // Unmatched types (e.g. FP8 / INT8) currently fall back to FLOAT_DATATYPE.
    return FLOAT_DATATYPE;
}
#if (CUDART_VERSION >= 11000)
// input, weight, output are row-major
// requires CUDA 11.0 or newer (see the surrounding CUDART_VERSION guard)
void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasOperation_t transb,
const int m,
const int n,
const int k,
const void* A,
const int lda,
const void* B,
const int ldb,
const void* bias,
void* C,
const int ldc) {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
cudaDataType_t Atype, Btype, Ctype;
cublasComputeType_t computeType;
cudaDataType_t scaleType;
float alpha_float = 1.0f;
float beta_float = 0.0f;
void * alpha, *beta;
    // alpha/beta are FP32 host scalars and scaleType is FP32 for all three branches.
    alpha     = &alpha_float;
    beta      = &beta_float;
    scaleType = CUDA_R_32F;
    if (Atype_ == CUDA_R_32F) {
        computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
        Atype = CUDA_R_32F;
        Btype = CUDA_R_32F;
        Ctype = CUDA_R_32F;
    } else if (Atype_ == CUDA_R_16BF) {
        computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
        Atype = CUDA_R_16BF;
        Btype = CUDA_R_16BF;
        Ctype = CUDA_R_16BF;
    } else {
        computeType = CUBLAS_COMPUTE_32F;
        Atype = CUDA_R_16F;
        Btype = CUDA_R_16F;
        Ctype = CUDA_R_16F;
    }
    bool findAlgo = cublas_algo_map_->isExist(1, m, n, k, getCublasDataType(Atype_));
    cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(1, m, n, k, getCublasDataType(Atype_));
cublasLtMatmulAlgo_t algo;
void* workSpace = cublas_workspace_;
uint64_t workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (findAlgo && info.stages != -1 && info.workspaceSize <= workspaceSize) {
check_cuda_error(cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, scaleType, Atype_,
Btype_, Ctype_, Ctype_, info.algoId, &algo));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption),
sizeof(info.customOption)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID,
&(info.tile), sizeof(info.tile)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val),
sizeof(info.splitK_val)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle),
sizeof(info.swizzle)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(info.reductionScheme),
sizeof(info.reductionScheme)));
#if (CUDART_VERSION >= 11000)
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)));
#endif
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
&(info.cluster_shapeId),
sizeof(info.cluster_shapeId)));
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId)));
check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode)));
#endif
} else {
findAlgo = false;
}
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
check_cuda_error(cublasLtMatrixLayoutCreate(
&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, (transa == CUBLAS_OP_N) ? k : m, lda));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Adesc); });
check_cuda_error(cublasLtMatrixLayoutCreate(
&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, (transb == CUBLAS_OP_N) ? n : k, ldb));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Bdesc); });
check_cuda_error(cublasLtMatrixLayoutCreate(
&Cdesc, Ctype, m, n, ldc));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Cdesc); });
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
FT_SCOPE_GUARD([&](){ cublasLtMatmulDescDestroy(operationDesc); });
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void*)));
    check_cuda_error(cublasLtMatmul(cublaslt_handle_, operationDesc, alpha, A,
                                    Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc,
                                    (findAlgo ? &algo : NULL), workSpace, workspaceSize, stream_));
sync_check_cuda_error();
}
#endif
void cublasMMWrapper::setStream(cudaStream_t stream) {
stream_ = stream;
}
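// Batched GEMM over tensors laid out at fixed strides; the algo map is consulted
// for a tuned cublasGemmStridedBatchedEx algorithm id.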
void cublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa,
cublasOperation_t transb,
const int m,
const int n,
const int k,
const void* A,
const int lda,
const int64_t strideA,
const void* B,
const int ldb,
const int64_t strideB,
void* C,
const int ldc,
const int64_t strideC,
const int batch_count,
const float f_alpha,
const float f_beta) {
std::lock_guard<std::mutex> lock(*mutex_);
    half h_alpha = (half)f_alpha;
    half h_beta  = (half)f_beta;
    const bool is_fp16_computeType = (computeType_ == CUDA_R_16F);
    const void* alpha = is_fp16_computeType ? static_cast<const void*>(&h_alpha) : static_cast<const void*>(&f_alpha);
    const void* beta  = is_fp16_computeType ? static_cast<const void*>(&h_beta) : static_cast<const void*>(&f_beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle_,
transa,
transb,
m,
n,
k,
alpha,
A,
Atype_,
lda,
strideA,
B,
Btype_,
ldb,
strideB,
beta,
C,
Ctype_,
ldc,
strideC,
batch_count,
computeType_,
static_cast<cublasGemmAlgo_t>(info.algoId)));
}
void cublasMMWrapper::batchedGemm(cublasOperation_t transa,
cublasOperation_t transb,
const int m,
const int n,
const int k,
const void* const* A,
const int lda,
const void* const* B,
const int ldb,
void* const* C,
const int ldc,
const int batch_count,
const float alpha,
const float beta) {
std::lock_guard<std::mutex> lock(*mutex_);
    const half h_alpha = (half)alpha;
    const half h_beta  = (half)beta;
    const bool is_fp16_computeType = (computeType_ == CUDA_R_16F);
    const void* r_alpha = is_fp16_computeType ? static_cast<const void*>(&h_alpha) : static_cast<const void*>(&alpha);
    const void* r_beta  = is_fp16_computeType ? static_cast<const void*>(&h_beta) : static_cast<const void*>(&beta);
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
check_cuda_error(cublasGemmBatchedEx(cublas_handle_,
transa,
transb,
m,
n,
k,
r_alpha,
A,
Atype_,
lda,
B,
Btype_,
ldb,
r_beta,
C,
Ctype_,
ldc,
batch_count,
computeType_,
static_cast<cublasGemmAlgo_t>(info.algoId)));
}
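// Decide whether fusing batch_count small GEMMs into one batched call is
// profitable, based on tuned timings recorded in the algo map.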
bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, const int k, const int n) {
    CublasDataType data_type = getCublasDataType(Atype_);
    if (!cublas_algo_map_->isExist(batch_count, m, k, n, data_type)
        || !cublas_algo_map_->isExist(1, m, k, n, data_type)) {
        return false;
    }
    // Heuristic: treat the fused batch as profitable when it costs less than
    // three single GEMMs of the same shape (a tuned threshold, not a cost model).
    return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type).exec_time
           < 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time;
}
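// Ask cublasLt for a single heuristic algorithm candidate under the standard
// workspace limit; returns {found, algo}.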
std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findHeuristicAlgo(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc,
const void* alpha,
const void* A,
cublasLtMatrixLayout_t Adesc,
const void* B,
cublasLtMatrixLayout_t Bdesc,
const void* beta,
const void* C,
cublasLtMatrixLayout_t Cdesc,
void* D,
cublasLtMatrixLayout_t Ddesc) {
#if (CUBLAS_VERSION) <= 11402
RTP_LLM_FAIL("CUBLAS version too low.");
return {false, cublasLtMatmulAlgo_t{}};
#else
    size_t returnSize;
    int32_t pointer_mode;
    // The pointer mode is queried but does not currently influence selection.
    check_cuda_error(cublasLtMatmulDescGetAttribute(
        computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize));
    cublasLtMatmulHeuristicResult_t result = {};  // zero-init so .algo is defined when no candidate is returned
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
FT_SCOPE_GUARD([&]() { cublasLtMatmulPreferenceDestroy(preference); });
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
#if (CUBLAS_VERSION) <= 12000
    uint32_t epilogue_mask = 0;  // this preference attribute only exists pre-cuBLAS 12 (see guard)
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &epilogue_mask, sizeof(epilogue_mask)));
#endif
int return_count = 0;
auto ret = cublasLtMatmulAlgoGetHeuristic(lightHandle,
computeDesc,
Adesc,
Bdesc,
Cdesc,
Ddesc,
preference,
1,
&result,
&return_count);
check_cuda_error(ret);
return {return_count != 0, result.algo};
#endif
}
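// Exhaustive search: fetch up to 200 heuristic candidates, time each one 11
// times on the given stream, and pick the algorithm with the best median run.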
std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc,
const void* alpha,
const void* A,
cublasLtMatrixLayout_t Adesc,
const void* B,
cublasLtMatrixLayout_t Bdesc,
const void* beta,
const void* C,
cublasLtMatrixLayout_t Cdesc,
void* D,
cublasLtMatrixLayout_t Ddesc,
cudaStream_t stream) {
#if (CUBLAS_VERSION) <= 11402
RTP_LLM_FAIL("CUBLAS version too low.");
return {false, cublasLtMatmulAlgo_t{}};
#else
    size_t returnSize;
    int32_t pointer_mode;
    // The pointer mode is queried but does not currently influence selection.
    check_cuda_error(cublasLtMatmulDescGetAttribute(
        computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode), &returnSize));
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
FT_SCOPE_GUARD([&]() { cublasLtMatmulPreferenceDestroy(preference); });
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
#if (CUBLAS_VERSION) <= 12000
    uint32_t epilogue_mask = 0;  // this preference attribute only exists pre-cuBLAS 12 (see guard)
    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &epilogue_mask, sizeof(epilogue_mask)));
#endif
int return_count = 0;
auto ret = cublasLtMatmulAlgoGetHeuristic(lightHandle,
computeDesc,
Adesc,
Bdesc,
Cdesc,
Ddesc,
preference,
heuristics.size(),
heuristics.data(),
&return_count);
check_cuda_error(ret);
heuristics.resize(return_count);
std::map<int, std::vector<float>> algo_results;
cudaEvent_t start_event, stop_event;
check_cuda_error(cudaEventCreate(&start_event));
FT_SCOPE_GUARD([&]() { cudaEventDestroy(start_event); });
check_cuda_error(cudaEventCreate(&stop_event));
FT_SCOPE_GUARD([&]() { cudaEventDestroy(stop_event); });
for (const auto& heuristic : heuristics) {
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
check_cuda_error(cublasLtMatmulAlgoConfigGetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize));
        // Time each candidate 11 times; the runs are sorted below and ranked by median.
        for (int i = 0; i < 11; i++) {
float duration_ms;
cudaEventRecord(start_event, stream);
check_cuda_error(cublasLtMatmul(lightHandle,
computeDesc,
alpha,
A,
Adesc,
B,
Bdesc,
beta,
C,
Cdesc,
D,
Ddesc,
&algo,
cublas_workspace_,
CUBLAS_WORKSPACE_SIZE,
stream));
cudaEventRecord(stop_event, stream);
cudaEventSynchronize(stop_event);
cudaEventElapsedTime(&duration_ms, start_event, stop_event);
algo_results[algo_id].push_back(duration_ms);
}
std::sort(algo_results[algo_id].begin(), algo_results[algo_id].end());
}
    cublasLtMatmulHeuristicResult_t result = {};  // zero-init: .algo must be defined even if no run beats INFINITY
float best_time = INFINITY;
for (const auto& heuristic : heuristics) {
cublasLtMatmulAlgo_t algo = heuristic.algo;
int32_t algo_id;
check_cuda_error(cublasLtMatmulAlgoConfigGetAttribute(
&algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize));
        const auto& results = algo_results[algo_id];
        // Rank candidates by the median (index 5 of the 11 sorted timings).
        if (results.size() > 5 && results[5] < best_time) {
            best_time = results[5];
result = heuristic;
}
}
return {best_time != INFINITY, result.algo};
#endif
}
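// Snapshot the (type, order, rows, cols) of a layout descriptor; these tuples
// form the key used by the algo cache in cublasLtMatmulWrapper below.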
cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc) {
size_t returnSize;
MatrixLayout m_layout;
check_cuda_error(cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize));
check_cuda_error(cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &std::get<1>(m_layout), sizeof(std::get<1>(m_layout)), &returnSize));
check_cuda_error(cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize));
check_cuda_error(cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize));
return m_layout;
}
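// cublasLtMatmul with algorithm memoization: when the caller supplies no algo,
// look up (or populate) the cache keyed on the descriptor snapshots, using
// either the full benchmark search or the single-candidate heuristic.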
cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc,
const void* alpha,
const void* A,
cublasLtMatrixLayout_t Adesc,
const void* B,
cublasLtMatrixLayout_t Bdesc,
const void* beta,
const void* C,
cublasLtMatrixLayout_t Cdesc,
void* D,
cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t* algo,
void* workspace,
size_t workspaceSizeInBytes,
cudaStream_t stream,
bool findBest) {
cache_idx_t cache_idx{
computeDesc,
{createMatrixLayout(Adesc), createMatrixLayout(Bdesc), createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}};
cublasLtMatmulAlgo_t algo_value;
bool found_algo = false;
if (algo == nullptr) {
auto it = algo_cache.find(cache_idx);
if (it == algo_cache.end()) {
std::pair<bool, cublasLtMatmulAlgo_t> result;
if (findBest) {
result = findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, stream);
} else {
result = findHeuristicAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc);
}
if (result.first) {
algo_cache[cache_idx] = result.second;
algo_value = result.second;
found_algo = true;
}
} else {
algo_value = it->second;
found_algo = true;
}
}
return cublasLtMatmul(lightHandle,
computeDesc,
alpha,
A,
Adesc,
B,
Bdesc,
beta,
C,
Cdesc,
D,
Ddesc,
found_algo ? &algo_value : algo,
workspace,
workspaceSizeInBytes,
stream);
}
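// Shared int8 GEMM core. mode 0 requantizes the int32 accumulator back to int8
// using device-side scales (per-tensor or per-column); mode 1 returns raw int32.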
void cublasMMWrapper::_Int8Gemm(const int m,
const int n,
const int k,
const int8_t* A,
const int lda,
const int8_t* B,
const int ldb,
void* C,
const int ldc,
const void* alpha,
const int mode,
const bool per_column_scaling) {
/* mode:
* - 0: int8 * int8 -> int32 -> int8
* - 1: int8 * int8 -> int32 -> int32
*/
#if (CUBLAS_VERSION) <= 11402
RTP_LLM_FAIL("CUBLAS version too low.");
#else
std::lock_guard<std::mutex> lock(*mutex_);
const auto op_a = CUBLAS_OP_T;
const auto op_b = CUBLAS_OP_N;
const auto dataType = CUDA_R_8I;
const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I;
const auto computeType = CUBLAS_COMPUTE_32I;
const auto scaleType = mode == 0 ? CUDA_R_32F : CUDA_R_32I;
    const void* beta = nullptr;  // set in both mode branches below; initialized defensively
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Adesc); });
check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Bdesc); });
check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc));
FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Cdesc); });
check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
FT_SCOPE_GUARD([&](){ cublasLtMatmulDescDestroy(operationDesc); });
    auto pointer_mode = CUBLASLT_POINTER_MODE_HOST;
    if (mode == 0) {
        // int8 requantization reads alpha (the scale) from device memory.
        pointer_mode =
            per_column_scaling ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST : CUBLASLT_POINTER_MODE_DEVICE;
    }
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, &op_b, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
const int32_t int_one = 1;
const int32_t int_zero = 0;
const float float_zero = 0;
if (mode == 0) {
beta = per_column_scaling ? &float_zero : NULL;
} else {
alpha = &int_one;
beta = &int_zero;
}
void* workSpace = cublas_workspace_;
uint64_t workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
sync_check_cuda_error();
auto ret = cublasLtMatmulWrapper(cublaslt_handle_,
operationDesc,
alpha,
A,
Adesc,
B,
Bdesc,
beta,
C,
Cdesc,
C,
Cdesc,
NULL,
workSpace,
workspaceSize,
stream_);
check_cuda_error(ret);
sync_check_cuda_error();
#endif
}
void cublasMMWrapper::Int8Gemm(const int m,
const int n,
const int k,
const int8_t* A,
const int lda,
const int8_t* B,
const int ldb,
int8_t* C,
const int ldc,
const float* alpha,
const bool per_column_scaling) {
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, per_column_scaling);
}
void cublasMMWrapper::Int8Gemm(const int m,
const int n,
const int k,
const int8_t* A,
const int lda,
const int8_t* B,
const int ldb,
int32_t* C,
const int ldc) {
return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float*)nullptr, 1, false);
}
} // namespace rtp_llm