in maga_transformer/cpp/cuda/cublas/cublasMMWrapper.cc [87:256]
// Runs D = alpha * op(A) * op(B) + beta * C through cublasLt (column-major,
// as all cuBLAS APIs are), writing the result in place into C.
//
// Parameters:
//   handle            - unused here; the member cublaslt_handle_ is used instead.
//   transa/transb     - op() applied to A / B.
//   m, n, k           - GEMM dimensions (of the op()'d matrices).
//   alpha, beta       - scaling factors (host or device pointers).
//   A, B, C           - device matrices; A_scale/B_scale are optional per-tensor
//                       dequant scale pointers (FP8 paths); pass nullptr to skip.
//   Atype/Btype/Ctype - cudaDataType of each matrix; lda/ldb/ldc leading dims.
//   is_fp16_computeType - selects FP16 vs FP32 compute/scale type.
//   info, findAlgo    - caller-supplied algo config; used only when findAlgo is
//                       set AND its workspace requirement fits our workspace.
//   math_sm_count     - if > 0, caps the SM count cublasLt may use.
//   fast_accum        - CUBLASLT_MATMUL_DESC_FAST_ACCUM value (FP8 fast accumulation).
//   stream            - stream to run on; a non-default stream gets its own
//                       lazily-allocated workspace so concurrent streams do not
//                       share one scratch buffer.
void cublasMMWrapper::cublasLtGemm(cublasHandle_t handle,
                                   cublasOperation_t transa,
                                   cublasOperation_t transb,
                                   int m,
                                   int n,
                                   int k,
                                   const void* alpha, /* host or device pointer */
                                   const void* A,
                                   const void* A_scale,
                                   cudaDataType Atype,
                                   int lda,
                                   const void* B,
                                   const void* B_scale,
                                   cudaDataType Btype,
                                   int ldb,
                                   const void* beta, /* host or device pointer */
                                   void* C,
                                   cudaDataType Ctype,
                                   int ldc,
                                   bool is_fp16_computeType,
                                   cublasLtMatmulAlgo_info info,
                                   bool findAlgo,
                                   int math_sm_count,
                                   int8_t fast_accum,
                                   cudaStream_t stream) {
    cublasLtMatrixLayout_t Adesc;
    cublasLtMatrixLayout_t Bdesc;
    cublasLtMatrixLayout_t Cdesc;
    cublasLtMatmulDesc_t operationDesc;
    cudaDataType_t scaleType;
    // CUDA 11+ separates the compute type from the scale type.
#if (CUDART_VERSION >= 11000)
    cublasComputeType_t computeType;
#else
    cudaDataType_t computeType;
#endif
    if (is_fp16_computeType) {
#if (CUDART_VERSION >= 11000)
        computeType = CUBLAS_COMPUTE_16F;
#else
        computeType = CUDA_R_16F;
#endif
        scaleType = CUDA_R_16F;
    } else {
#if (CUDART_VERSION >= 11000)
        computeType = CUBLAS_COMPUTE_32F;
#else
        computeType = CUDA_R_32F;
#endif
        scaleType = CUDA_R_32F;
    }

    // --------------------------------------
    // Create layout descriptors for the original (pre-op) matrices.
    // Rows/cols swap when the matrix is transposed.
    // NOTE: the D output reuses C's buffer and layout (see the matmul call
    // below), so no separate D descriptor is needed. The previous code built
    // one with Btype — which was both unused and the wrong element type.
    check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, Atype,
        transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Adesc); });
    check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, Btype,
        transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Bdesc); });
    check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc));
    FT_SCOPE_GUARD([&](){ cublasLtMatrixLayoutDestroy(Cdesc); });

#if (CUDART_VERSION >= 11000)
    check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType));
#else
    check_cuda_error(cublasLtMatmulDescCreate(&operationDesc, computeType));
#endif
    FT_SCOPE_GUARD([&](){ cublasLtMatmulDescDestroy(operationDesc); });

    if (math_sm_count > 0) {
        check_cuda_error(cublasLtMatmulDescSetAttribute(
            operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count)));
    }
    check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
        CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
    check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
        CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
    check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
        CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fast_accum, sizeof(int8_t)));
    // Optional FP8 dequantization scales; cublasLt stores the POINTER value,
    // so we pass the address of the pointer.
    if (A_scale != nullptr) {
        check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
            CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &A_scale, sizeof(void*)));
    }
    if (B_scale != nullptr) {
        check_cuda_error(cublasLtMatmulDescSetAttribute(operationDesc,
            CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale, sizeof(void*)));
    }

    cublasLtMatmulAlgo_t algo;
    void* workSpace = cublas_workspace_;
    uint64_t workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
    if (stream != stream_) {
        // Side streams get their own workspace, allocated once per stream and
        // cached, so concurrent GEMMs never race on the shared scratch buffer.
        if (cublas_workspces_map_.count(stream) == 0) {
            void* additional_cublas_workspace = nullptr;
            additional_cublas_workspace = allocator_->reMalloc(additional_cublas_workspace, CUBLAS_WORKSPACE_SIZE);
            additional_cublas_workspaces_.push_back(additional_cublas_workspace);
            cublas_workspces_map_[stream] = additional_cublas_workspaces_.size() - 1;
        }
        workSpace = additional_cublas_workspaces_[cublas_workspces_map_[stream]];
        // The per-stream workspace is always CUBLAS_WORKSPACE_SIZE bytes.
        // (Previously this was gated on cublas_workspace_ == NULL, which could
        // report size 0 for a perfectly valid buffer.)
        workspaceSize = CUBLAS_WORKSPACE_SIZE;
        // %p for the stream handle (it is a pointer; %d is UB on 64-bit).
        RTP_LLM_LOG_DEBUG("stream %p, idx %d", (void*)stream, (int)cublas_workspces_map_[stream]);
    }

    if (findAlgo) {
        if (info.workspaceSize > workspaceSize) {
            // The chosen algo needs more scratch than we have; fall back to
            // letting cublasLt pick a heuristic algo (NULL below).
            findAlgo = false;
        } else {
            check_cuda_error(cublasLtMatmulAlgoInit(
                cublaslt_handle_, computeType, scaleType, Atype, Btype, Ctype, Ctype, info.algoId, &algo));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), sizeof(info.customOption)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(info.tile), sizeof(info.tile)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), sizeof(info.splitK_val)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), sizeof(info.swizzle)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                                                  CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
                                                                  &(info.reductionScheme),
                                                                  sizeof(info.reductionScheme)));
#if (CUDART_VERSION >= 11000)
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), sizeof(info.stages)));
#endif
            // Shape-config attribute names changed within the cuBLAS 11.11 line.
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), sizeof(info.inner_shapeId)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(&algo,
                                                                  CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID,
                                                                  &(info.cluster_shapeId),
                                                                  sizeof(info.cluster_shapeId)));
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), sizeof(info.mma_shapeId)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), sizeof(info.cga_shapeId)));
            check_cuda_error(cublasLtMatmulAlgoConfigSetAttribute(
                &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), sizeof(info.sche_mode)));
#endif
        }
    }

    // D aliases C (same buffer, same layout) — an in-place epilogue-free GEMM.
    check_cuda_error(cublasLtMatmulWrapper(cublaslt_handle_,
                                           operationDesc,
                                           alpha,
                                           A,
                                           Adesc,
                                           B,
                                           Bdesc,
                                           beta,
                                           C,
                                           Cdesc,
                                           C,
                                           Cdesc,
                                           (findAlgo ? (&algo) : NULL),
                                           workSpace,
                                           workspaceSize,
                                           stream,
                                           /* find_best = */ false));
    sync_check_cuda_error();
}