in maga_transformer/cpp/rocm/int4_gemm_kernels/int4_dequant_comm.h [123:201]
// Launches one int4 GEMM through DeviceInt4GemmInstance.
//
// Reads the problem sizes (M, N, K), strides, group size, the A / B / C /
// B-scale buffers and the stream from `params`, selects a KBatch value from a
// small table of tuned shapes, builds the device argument and runs the
// invoker on `params.stream`.
//
// If the instance reports that the argument is unsupported, a message is
// written to std::cerr and the function returns without launching anything.
// NOTE(review): the caller gets no other signal in that case, so the output
// buffer is left untouched — confirm callers can tolerate that.
void int4Gemm_impl(const ckGemmParam& params) {
    // Get input information.
    auto M         = params.M;
    auto N         = params.N;
    auto K         = params.K;
    auto GroupSize = params.Group_size;
    auto StrideA   = params.StrideA;
    auto StrideB   = params.StrideB;
    auto StrideC   = params.StrideC;
    // ceil(K / GroupSize); presumably the number of scale groups along K — TODO confirm.
    auto StrideScaleB = (K + GroupSize - 1) / GroupSize;

    // KBatch tuning: exact (N, K, M) matches get a hand-picked value, every
    // other shape keeps the default of 1.
    auto KBatch = 1;
    if (N == 29696 && K == 8192) {
        if (M == 1 || M == 16 || M == 48 || M == 64) {
            KBatch = 2;
        } else if (M == 32) {
            KBatch = 4;
        }
    } else if (N == 8192 && K == 29696) {
        if (M == 64) {
            KBatch = 2;
        } else if (M == 16 || M == 48) {
            KBatch = 4;
        } else if (M == 1 || M == 32 || M == 80 || M == 96 || M == 112 || M == 128) {
            KBatch = 8;
        }
    } else if (N == 10240 && K == 8192) {
        if (M == 16 || M == 80 || M == 112 || M == 128) {
            KBatch = 2;
        } else if (M == 1 || M == 32) {
            KBatch = 4;
        }
    } else if (N == 8192 && K == 8192) {
        if (M == 112 || M == 128) {
            KBatch = 2;
        } else if (M == 1 || M == 16) {
            KBatch = 4;
        }
    }

    // Create gemm launcher and arguments.
    auto gemm         = DeviceInt4GemmInstance{};
    auto invoker      = gemm.MakeInvoker();
    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    auto argument = gemm.MakeArgument(static_cast<ADataType*>(params.A_input),
                                      static_cast<BDataType*>(params.B_input),
                                      static_cast<CDataType*>(params.C_input),
                                      M,
                                      N,
                                      K,
                                      StrideA,
                                      StrideB,
                                      StrideC,
                                      StrideScaleB,
                                      static_cast<BScaleDataType*>(params.B_scales_input),
                                      KBatch,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if (!gemm.IsSupportedArgument(argument)) {
        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
        return;
    }

    // The workspace buffer must stay alive until after invoker.Run(): the
    // argument only stores the raw pointer. Declaring the DeviceMem inside
    // the `if` below would destroy it at the end of that block, before the
    // launch, leaving the argument with a dangling workspace pointer. It is
    // therefore owned at function scope.
    // NOTE(review): this constructs DeviceMem even when workspace_size is 0;
    // that relies on a zero-byte device allocation being a harmless no-op
    // (see the HIP hipMalloc documentation) — confirm for the CK version in use.
    const std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
    DeviceMem         gemm_desc_workspace(workspace_size);
    if (workspace_size != 0) {
        gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
    }

    invoker.Run(argument, StreamConfig{params.stream, false});
}