in maga_transformer/cpp/devices/rocm_impl/ROCmGemmOp.cc [167:343]
BufferPtr ROCmDevice::gemm(const GemmParams& params) {
    params.check();

    using GemmImplementType = ROCmGemmDispatch::GemmImplementType;
    ROCmGemmArguments arguments(params);

    // Write into the caller-provided D if given, otherwise allocate the output buffer.
    BufferPtr output;
    if (params.D) {
        output = params.D;
        RUNTIME_ASSERT_OP_ARG((arguments.DDtype == params.D->type()) && (arguments.Dshape == params.D->shape()),
                              "Gemm output D shape and dtype mismatch: expected [%d][%s] but got [%s]",
                              arguments.DDtype,
                              autil::StringUtil::toString(arguments.Dshape).c_str(),
                              params.D->debugString().c_str());
    } else {
        output = allocateBuffer({arguments.DDtype, arguments.Dshape, AllocationType::DEVICE}, {"gemm_output"});
    }
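
    // Quantized-weight path: FP16/BF16 activations in A, per-group quantized weights in B (a QBuffer).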
    if (params.dispatch() == GemmType::BufferA_QBufferB_BufferC_2DGemm) {
        if (reinterpret_cast<const QBuffer&>(params.B).zerosData() != nullptr) {
            ROCM_CHECK_VALUE(reinterpret_cast<const QBuffer&>(params.B).scales().dim() == 2,
                             "scales().dim() = %d",
                             reinterpret_cast<const QBuffer&>(params.B).scales().dim());
            size_t kernel_dim0 = params.B.shape()[0];
            size_t scales_dim0 = reinterpret_cast<const QBuffer&>(params.B).scales().shape()[0];
            ROCM_CHECK_VALUE((kernel_dim0 % scales_dim0 == 0), "kernel_dim0 % scales_dim0 != 0");
            size_t group_size = (kernel_dim0 / scales_dim0);
            ROCM_CHECK_VALUE((group_size == 64 || group_size == 128), "group_size != 64 and group_size != 128");
            size_t type_bits = getTypeBits(params.B.type());
            ROCM_CHECK_VALUE((type_bits == 4 || type_bits == 8), "type_bits != 4 and type_bits != 8");

            BUFFER_DTYPE_CHECK(params.A, {DataType::TYPE_FP16, DataType::TYPE_BF16});
            BUFFER_DTYPE_CHECK(params.B, {DataType::TYPE_QINT4X2});

            const QBuffer& QB = reinterpret_cast<const QBuffer&>(params.B);
            auto fpB = allocateBuffer({params.A.type(), {params.B.shape()}, AllocationType::DEVICE}, {"fpB"});
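
            // Two implementations: a CK kernel that fuses int4 dequantization into the GEMM, or a
            // fallback that first dequantizes B into fpB and then runs a plain hipBLAS GEMM on it.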
#if USING_CK_INT4
            // Use the CK int4-dequant fusion GEMM kernel.
            auto ck_gemm_params = ckGemmParam({params.A.data(),
                                               QB.kernel().data(),
                                               QB.scales().data(),
                                               QB.zeros().data(),
                                               output->data(),
                                               arguments.m,
                                               arguments.n,
                                               arguments.k,
                                               group_size,
                                               arguments.k,  // arguments.lda
                                               arguments.k,  // arguments.ldb
                                               arguments.n,  // arguments.ldc
                                               stream_});
            ck_gemm_runner_->runCKGemm(ck_gemm_params, params.A.type(), params.B.type());
#else
            // Dequantize the int4x2 weights in B into fpB at the activation dtype.
            DISPATCH_CUDA_FUNCTION_DATA_TYPE(params.A.type(),
                                             invokePerColDequantizationInt4x2,
                                             fpB.get()->data(),
                                             arguments.k,
                                             arguments.n,
                                             group_size,
                                             (int8_t*)(QB.kernel().data()),
                                             QB.scales().data<half>(),
                                             QB.zeros().data<half>(),
                                             stream_);

            const auto A = params.A.data();
            const auto B = fpB.get()->data();
            auto D = output->data();

            auto a_op = opConvert(params.transA);
            auto b_op = opConvert(params.transB);

            auto A_data_type = dtypeConvert(arguments.ADtype);
            auto B_data_type = dtypeConvert(fpB.get()->type());
            auto D_data_type = dtypeConvert(arguments.DDtype);
            auto computeType = dtypeConvert(arguments.DDtype);
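            // hipBLAS works in column-major order, so B/A and n/m are passed swapped:
            // computing D^T = B^T * A^T column-major yields the row-major D = A * B.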
            hipblas_mm_wrapper_->stridedBatchedGemm(b_op,
                                                    a_op,
                                                    arguments.n,
                                                    arguments.m,
                                                    arguments.k,
                                                    arguments.alpha,
                                                    B,
                                                    B_data_type,
                                                    arguments.ldb,
                                                    arguments.stride_b,
                                                    A,
                                                    A_data_type,
                                                    arguments.lda,
                                                    arguments.stride_a,
                                                    arguments.beta,
                                                    D,
                                                    D_data_type,
                                                    arguments.ldc,
                                                    arguments.stride_c,
                                                    arguments.batch_size,
                                                    computeType);
#endif
            return std::move(output);
        } else {
            ROCM_FAIL("[GEMM]: Other weight quantization not implemented");
        }
    }
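
    // Non-quantized path: pick the hipBLAS compute type and configure the wrapper with the operand dtypes.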
    auto A_data_type = dtypeConvert(arguments.ADtype);
    auto B_data_type = dtypeConvert(arguments.BDtype);
    auto D_data_type = dtypeConvert(arguments.DDtype);
    auto computeType = HIPBLAS_R_32F;
    if (params.compute_type == DataType::TYPE_INVALID) {
        computeType = HIPBLAS_R_32F;
        hipblasMMWrapperPtr()->setGemmConfig(A_data_type, B_data_type, D_data_type, HIPBLAS_R_32F);
    } else {
        computeType = dtypeConvert(arguments.DDtype);
        hipblasMMWrapperPtr()->setGemmConfig(A_data_type, B_data_type, D_data_type, dtypeConvert(params.compute_type));
    }
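
    // Dispatch to either a plain 2-D GEMM or a strided batched GEMM, both through hipBLAS.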
    if (ROCmGemmDispatch::dispatch(params) == GemmImplementType::hipblas_basic_gemm) {
        const auto A = params.A.data();
        const auto B = params.B.data();
        auto D = output->data();

        auto a_op = opConvert(params.transA);
        auto b_op = opConvert(params.transB);

        hipblas_mm_wrapper_->setStream(current_stream_);
        hipblas_mm_wrapper_->Gemm(
            b_op, a_op, arguments.n, arguments.m, arguments.k, B, arguments.ldb, A, arguments.lda, D, arguments.ldc);
        return std::move(output);
    } else if (ROCmGemmDispatch::dispatch(params) == GemmImplementType::hipblas_batch_gemm) {
        // convert buffers to ptrs
        const auto A = params.A.data();
        const auto B = params.B.data();
        auto D = output->data();

        auto a_op = opConvert(params.transA);
        auto b_op = opConvert(params.transB);

        auto A_data_type = dtypeConvert(arguments.ADtype);
        auto B_data_type = dtypeConvert(arguments.BDtype);
        auto D_data_type = dtypeConvert(arguments.DDtype);
        auto computeType = dtypeConvert(arguments.DDtype);
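        // Same column-major operand swap as above; stride_a/b/c advance between the matrices of the batch.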
        hipblas_mm_wrapper_->stridedBatchedGemm(b_op,
                                                a_op,
                                                arguments.n,
                                                arguments.m,
                                                arguments.k,
                                                arguments.alpha,
                                                B,
                                                B_data_type,
                                                arguments.ldb,
                                                arguments.stride_b,
                                                A,
                                                A_data_type,
                                                arguments.lda,
                                                arguments.stride_a,
                                                arguments.beta,
                                                D,
                                                D_data_type,
                                                arguments.ldc,
                                                arguments.stride_c,
                                                arguments.batch_size,
                                                computeType);
        return std::move(output);
    } else {
        ROCM_FAIL("[GEMM]: other dispatch not implemented");
    }
    return std::move(output);
}