in maga_transformer/cpp/devices/arm_impl/ArmGemmKaiOp.cc [299:475]
/// @brief GEMM using KleidiAI micro-kernels with a dynamically quantized LHS and a packed RHS.
///
/// A (FP32 or FP16) is quantized and packed with `kai_run_lhs_quant_pack_qsi8d32p_*`
/// using a block length of 32, then multiplied with B through the FP32 or FP16
/// ukernel variant table, selected by the output data type.
///
/// `params.B.data()` is handed to the micro-kernel unchanged, so B is presumably
/// already packed in the matching qsi4c32p layout by the caller — TODO confirm.
///
/// NOTE(review): the leading (batch) dimensions of A are copied into the output
/// shape, but only one `m x k` by `k x n` product is computed here; likewise the
/// transpose flags only swap the shape entries. Confirm callers pass plain 2-D,
/// non-transposed operands.
///
/// @param params GEMM arguments. `params.D`, when set, is used as the output and
///               must match the computed dtype and shape. When
///               `params.compute_type` is `TYPE_INVALID` the output dtype
///               defaults to the dtype of A.
/// @return The output buffer (either `params.D` or a newly allocated one).
/// @throws Whatever `RUNTIME_ASSERT_OP_ARG` raises when the output dtype or the
///         dtype of A is neither FP32 nor FP16, or when `params.D` mismatches.
BufferPtr ArmCpuDevice::gemm_kai_a8w4(const GemmParams& params) {
#ifdef GEMM_DEBUG
    const auto start = std::chrono::high_resolution_clock::now();
#endif
    params.check();

    std::vector<size_t> Ashape = params.A.shape();
    std::vector<size_t> Bshape = params.B.shape();
    const size_t        dim    = params.A.dim();

    if (params.transA == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Ashape.end() - 1, Ashape.end() - 2);
    }
    if (params.transB == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Bshape.end() - 1, Bshape.end() - 2);
    }

    const size_t m = Ashape[dim - 2];
    const size_t k = Ashape[dim - 1];
    const size_t n = Bshape[dim - 1];

    const auto data_type = params.compute_type == DataType::TYPE_INVALID ? params.A.type() : params.compute_type;
    const auto lhs_type  = params.A.type();

    // Fail fast on unsupported types. Previously these cases only logged a warning
    // and then continued with uninitialised packing parameters / an unpacked LHS.
    RUNTIME_ASSERT_OP_ARG((data_type == DataType::TYPE_FP32 || data_type == DataType::TYPE_FP16),
                          "Not supported GEMM output type %d",
                          static_cast<int>(data_type));
    RUNTIME_ASSERT_OP_ARG((lhs_type == DataType::TYPE_FP32 || lhs_type == DataType::TYPE_FP16),
                          "Not supported GEMM A type %d",
                          static_cast<int>(lhs_type));

    std::vector<size_t> Dshape(Ashape.begin(), Ashape.end() - 2);
    Dshape.insert(Dshape.end(), {m, n});

    BufferPtr output;
    if (params.D) {
        output = params.D;
        RUNTIME_ASSERT_OP_ARG((data_type == params.D->type()) && (Dshape == params.D->shape()),
                              "Gemm output D shape and dtype mismatch: expected [%d][%s] but got [%s]",
                              data_type,
                              autil::StringUtil::toString(Dshape).c_str(),
                              params.D->debugString().c_str());
    } else {
        output = allocateBuffer({data_type, Dshape, AllocationType::DEVICE}, {"gemm_output"});
    }

    // Variant 0 is used for a single-row LHS (m == 1), variant 1 for everything else.
    const size_t idx_variant = (m == 1) ? 0 : 1;

    // Packing parameters of the selected micro-kernel.
    size_t mr = 0;
    size_t kr = 0;
    size_t sr = 0;
    if (data_type == DataType::TYPE_FP32) {
        const auto& ukernel = fp32_ukernel_variants[idx_variant].ukernel;
        mr                  = ukernel.get_mr();
        kr                  = ukernel.get_kr();
        sr                  = ukernel.get_sr();
    } else {  // TYPE_FP16, guaranteed by the assertion above
        const auto& ukernel = fp16_ukernel_variants[idx_variant].ukernel;
        mr                  = ukernel.get_mr();
        kr                  = ukernel.get_kr();
        sr                  = ukernel.get_sr();
    }

    const size_t bl = 32;  // quantization block length along k

    // Scratch buffer for the quantized + packed LHS; owned by the vector so it is
    // released on every exit path.
    const size_t         lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32(m, k, bl, mr, kr, sr);
    std::vector<uint8_t> lhs_packed_storage(lhs_packed_size);
    uint8_t* const       lhs_packed = lhs_packed_storage.data();

    const uint8_t* const rhs_packed = (const uint8_t*)params.B.data();
    const uint8_t* const lhs_raw    = (const uint8_t*)params.A.data();
    uint8_t* const       dst_raw    = (uint8_t*)output->data();

    const size_t n_step = 32;  // 32 is the best for performance
    const size_t m_step = mr;

    // LHS packing: quantize and pack A in row tiles of `mr` rows.
    if (lhs_type == DataType::TYPE_FP32) {
        const size_t lhs_stride = k * sizeof(float);
#pragma omp parallel for if (m > 1)
        for (size_t m_start = 0; m_start < m; m_start += m_step) {
            const size_t src_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(m_start, lhs_stride);
            const size_t packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32(m_start, k, bl, mr, kr, sr);
            const size_t tile_m = std::min(m_step, m - m_start);
            kai_run_lhs_quant_pack_qsi8d32p_f32(tile_m, k, bl, mr, kr, sr, 0,
                                                (const float*)(lhs_raw + src_offset),
                                                lhs_stride,
                                                lhs_packed + packed_offset);
        }
    } else {  // TYPE_FP16, guaranteed by the assertion above
        const size_t lhs_stride = k * sizeof(float16_t);
#pragma omp parallel for if (m > 1)
        for (size_t m_start = 0; m_start < m; m_start += m_step) {
            const size_t src_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f16(m_start, lhs_stride);
            const size_t packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f16(m_start, k, bl, mr, kr, sr);
            const size_t tile_m = std::min(m_step, m - m_start);
            kai_run_lhs_quant_pack_qsi8d32p_f16(tile_m, k, bl, mr, kr, sr, 0,
                                                (const float16_t*)(lhs_raw + src_offset),
                                                lhs_stride,
                                                lhs_packed + packed_offset);
        }
    }

    // Matmul: every iteration handles all m rows for one tile of up to `n_step` columns.
    if (data_type == DataType::TYPE_FP32) {
        const auto&  ukernel        = fp32_ukernel_variants[idx_variant].ukernel;
        const size_t dst_stride     = n * sizeof(float);
        const size_t dst_stride_col = sizeof(float);
#pragma omp parallel for
        for (size_t n_start = 0; n_start < n; n_start += n_step) {
            const size_t lhs_offset = ukernel.get_lhs_packed_offset(0, k, bl);
            const size_t rhs_offset = ukernel.get_rhs_packed_offset(n_start, k, bl);
            const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
            const size_t tile_n     = std::min(n_step, n - n_start);
            ukernel.run_matmul(m, tile_n, k, bl,                          // Dimensions
                               (const void*)(lhs_packed + lhs_offset),    // LHS packed
                               (const void*)(rhs_packed + rhs_offset),    // RHS packed
                               (float*)(dst_raw + dst_offset),            // DST
                               dst_stride,                                // DST stride (row)
                               dst_stride_col,                            // DST stride (col)
                               -FLT_MAX, FLT_MAX                          // Clamp range (effectively no clamp)
            );
        }
    } else {  // TYPE_FP16, guaranteed by the assertion above
        const auto&  ukernel        = fp16_ukernel_variants[idx_variant].ukernel;
        const size_t dst_stride     = n * sizeof(float16_t);
        const size_t dst_stride_col = sizeof(float16_t);
#pragma omp parallel for
        for (size_t n_start = 0; n_start < n; n_start += n_step) {
            const size_t lhs_offset = ukernel.get_lhs_packed_offset(0, k, bl);
            const size_t rhs_offset = ukernel.get_rhs_packed_offset(n_start, k, bl);
            const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
            const size_t tile_n     = std::min(n_step, n - n_start);
            ukernel.run_matmul(m, tile_n, k, bl,                          // Dimensions
                               (const void*)(lhs_packed + lhs_offset),    // LHS packed
                               (const void*)(rhs_packed + rhs_offset),    // RHS packed
                               (float16_t*)(dst_raw + dst_offset),        // DST
                               dst_stride,                                // DST stride (row)
                               dst_stride_col,                            // DST stride (col)
                               -HALF_FLT_MAX, HALF_FLT_MAX                // Clamp range
            );
        }
    }

#ifdef GEMM_DEBUG
    const auto  end         = std::chrono::high_resolution_clock::now();
    const float during_time = std::chrono::duration<float>(end - start).count();
    printf("gemm_kai_a8w4 m,n,k %zu %zu %zu %.3f\n", m, n, k, during_time * 1000);
#endif
    return output;
}