in src/runtime/contrib/cublas/cublas.cc [139:306]
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
size_t workspace_size, cublasLtEpilogue_t epilogue,
std::optional<float> dq_scale) {
ICHECK(TypeEqual(A->dtype, B->dtype));
  // Reversed strides indicate an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
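  // For example, a tensor of logical shape (r, c) stored with strides (1, r) occupies the
  // same buffer as a contiguous (c, r) matrix, so flipping the transpose flag reads it
  // correctly without materializing a copy.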
auto compute_type = CUBLAS_COMPUTE_32F;
auto scale_type = CUDA_R_32F;
cudaDataType_t ab_type = CUDA_R_32F;
cudaDataType_t c_type = CUDA_R_32F;
float one_fp32 = 1.0;
float zero_fp32 = 0.0;
int32_t one_i32 = 1;
int32_t zero_i32 = 0;
  // Pass the dequantization scale through the "alpha" parameter. If there is no
  // dequantization after the matmul, then alpha == 1.0.
float alpha_value = dq_scale.value_or(one_fp32);
void* alpha = &alpha_value;
void* beta = &zero_fp32;
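  // For example, an fp8 e4m3 matmul that is dequantized back to a floating-point output
  // would typically fold its per-tensor quantization scales into a single multiplier,
  // dq_scale = scale_a * scale_b, giving C = dq_scale * (A * B).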
if (TypeMatch(A->dtype, kDLFloat, 16)) {
ab_type = CUDA_R_16F;
} else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
ab_type = CUDA_R_16BF;
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
ab_type = CUDA_R_8I;
} else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8));
ab_type = CUDA_R_8F_E4M3;
}
if (TypeMatch(C->dtype, kDLFloat, 16)) {
c_type = CUDA_R_16F;
} else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
c_type = CUDA_R_16BF;
} else if (TypeMatch(C->dtype, kDLInt, 32)) {
c_type = CUDA_R_32I;
compute_type = CUBLAS_COMPUTE_32I;
scale_type = CUDA_R_32I;
alpha = &one_i32;
beta = &zero_i32;
}
cublasLtMatmulDesc_t op_desc;
cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
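  // cuBLAS is column-major while TVM tensors are row-major. Rather than transposing the
  // inputs, compute C^T = B^T * A^T: the roles of A and B (and of their transpose flags
  // and dimensions) are swapped consistently below, so TRANSA is set from op_transb and
  // vice versa.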
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA,
&op_transb, sizeof(op_transb)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB,
&op_transa, sizeof(op_transa)));
  if (bias != nullptr) {
    // Account for byte_offset, consistent with the scale pointers below.
    auto bias_data = static_cast<char*>(bias->data) + bias->byte_offset;
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                                      &bias_data, sizeof(float*)));
  }
if (scaleA != nullptr) {
auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&scaleA_data, sizeof(float*)));
}
if (scaleB != nullptr) {
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&scaleB_data, sizeof(float*)));
}
if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
}
int batch_offset_A = A->ndim - 2;
int batch_offset_B = B->ndim - 2;
int M = ColumnCount(B, transb, batch_offset_B);
int N = RowCount(A, transa, batch_offset_A);
int K = ColumnCount(A, transa, batch_offset_A);
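  // The dimension names follow the swapped problem: M comes from B's non-reduction axis
  // and N from A's. For example, row-major A of shape (n, k) and B of shape (k, m) with
  // both transpose flags false give M = m, N = n, K = k.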
bool use_batched_gemm = A->ndim > 2 || B->ndim > 2;
  // If A is batched but B is not, flatten all non-reduction axes of A to use the regular
  // GEMM. This trick is only applicable if the batch axes and the other spatial axis
  // (M or N) are adjacent in both the input and the output matrix. In particular, if A
  // is of shape (M, K) and B is of shape (Batch, N, K) with transb = true, the output
  // shape is (Batch, M, N). Since the Batch and the N axes are not adjacent in the
  // output, we cannot use the regular GEMM if only B is batched.
if (A->ndim > 2 && B->ndim == 2 && transa == false) {
N = 1;
for (int i = 0; i < A->ndim - 1; ++i) {
N *= A->shape[i];
}
use_batched_gemm = false;
}
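  // For example, A of shape (2, 6, K) with B of shape (K, M) flattens to a (12, K) GEMM;
  // the (2, 6, M) output is contiguous and bit-identical to the (12, M) result.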
int lda = transb ? K : M;
int ldb = transa ? N : K;
int ldc = M;
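  // A row-major (r, c) matrix is bit-for-bit a column-major (c, r) matrix with leading
  // dimension c. Since A_desc below describes B's buffer (and B_desc describes A's),
  // lda is B's innermost extent (M, or K when transb), ldb is A's (K, or N when transa),
  // and C is written as a column-major (M, N) matrix with ldc = M.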
cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
CHECK_CUBLAS_ERROR(
cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda));
CHECK_CUBLAS_ERROR(
cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? N : K, ldb));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
if (use_batched_gemm) {
auto get_batch_count = [](int64_t* shape, int batch_offset) {
int64_t count = 1;
for (int i = 0; i < batch_offset; ++i) {
count *= shape[i];
}
return count;
};
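    // For example, a shape of (2, 3, n, k) with batch_offset == 2 yields a batch count
    // of 2 * 3 = 6.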
auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
CHECK_CUBLAS_ERROR(
cublasLtMatrixLayoutSetAttribute(mat_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride, sizeof(batch_stride)));
};
int batch_count_A = get_batch_count(A->shape, batch_offset_A);
int batch_count_B = get_batch_count(B->shape, batch_offset_B);
int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
int64_t batch_stride_A = M * K;
int64_t batch_stride_B = K * N;
int64_t batch_stride_C = M * N;
    // cuBLASLt does not seem to support batched GEMM where only one of the matrices is
    // batched (i.e. broadcasting the other with batch_stride 0), so the batch counts
    // must match.
    ICHECK_EQ(batch_count_A, batch_count_B);
set_batch(A_desc, batch_count_A, batch_stride_A);
set_batch(B_desc, batch_count_B, batch_stride_B);
set_batch(C_desc, batch_count_C, batch_stride_C);
}
auto A_data = static_cast<char*>(A->data) + A->byte_offset;
auto B_data = static_cast<char*>(B->data) + B->byte_offset;
auto C_data = static_cast<char*>(C->data) + C->byte_offset;
  CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
      matmul_pref_desc, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size,
      sizeof(size_t)));
cublasLtMatmulHeuristicResult_t heuristic_result = {};
int returned_result = 0;
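  // Request a single algorithm from the heuristic; zero returned results means cuBLASLt
  // has no kernel for this combination of types, shapes, and epilogue.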
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc,
matmul_pref_desc, 1, &heuristic_result,
&returned_result));
if (returned_result == 0) {
CHECK_CUBLAS_ERROR(CUBLAS_STATUS_NOT_SUPPORTED);
}
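  // Note the operand swap: B's buffer is passed with A_desc and A's with B_desc,
  // completing the C^T = B^T * A^T formulation set up above.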
CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
C_data, C_desc, C_data, C_desc, &heuristic_result.algo,
workspace_ptr, workspace_size, stream));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(op_desc));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(A_desc));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(B_desc));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(C_desc));
}
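
// Usage sketch (illustrative, not part of the TVM runtime): driving CallCublasLt for a
// plain row-major fp16 GEMM. The ad-hoc handle/preference/workspace management below is
// an assumption; inside TVM these objects normally come from the cuBLASLt thread-entry
// state, and `a`, `b`, `c` are assumed to be device-resident DLTensors of shape (n, k),
// (k, m), (n, m) with null strides and zero byte_offset.
//
//   cublasLtHandle_t hdl;
//   CHECK_CUBLAS_ERROR(cublasLtCreate(&hdl));
//   cublasLtMatmulPreference_t pref;
//   CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref));
//   size_t workspace_size = 32 << 20;  // 32 MiB scratch; the size is a tuning choice
//   void* workspace = nullptr;
//   cudaMalloc(&workspace, workspace_size);
//   CallCublasLt(hdl, /*stream=*/nullptr, pref, &a, &b, /*bias=*/nullptr,
//                /*scaleA=*/nullptr, /*scaleB=*/nullptr, &c, /*transa=*/false,
//                /*transb=*/false, workspace, workspace_size, CUBLASLT_EPILOGUE_DEFAULT,
//                /*dq_scale=*/std::nullopt);
//   cudaFree(workspace);
//   cublasLtMatmulPreferenceDestroy(pref);
//   cublasLtDestroy(hdl);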