// maga_transformer/cpp/devices/arm_impl/ArmGemmKaiOp.cc
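//
// GEMM ops backed by Arm KleidiAI (kai_*) micro-kernels:
//   - gemm_kai_bf16: fp32/fp16 GEMM computed through bf16 MMLA kernels
//   - gemm_kai_a8w4: dynamically quantized int8-activation x int4-weight GEMM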

#include "maga_transformer/cpp/devices/arm_impl/ArmDevice.h" #include "maga_transformer/cpp/devices/DeviceFactory.h" #include "maga_transformer/cpp/core/allocator.h" #include "maga_transformer/cpp/core/cpu_allocator.h" #include "maga_transformer/cpp/devices/utils/DebugUtils.h" #include <cstring> #include "autil/StringUtil.h" #include "gemm_opt/ArmGemmKernel.h" #include <cfloat> #include "maga_transformer/cpp/devices/utils/Timer.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f16.h" #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qsi4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qsi4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qsi4c32p/kai_matmul_clamp_f16_qsi8d32p_qsi4c32p_interface.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h" #include "kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h" namespace rtp_llm { static const float HALF_FLT_MAX = 65504.F; struct kai_matmul_ukernel_f32_qa8d32p_qs4c32p { kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_ukernel ukernel; std::string name = {}; }; struct kai_matmul_ukernel_f16_qa8d32p_qs4c32p { kai_matmul_clamp_f16_qsi8d32p_qsi4c32p_ukernel ukernel; std::string name = {}; }; kai_matmul_ukernel_f32_qa8d32p_qs4c32p fp32_ukernel_variants[] = { {kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, "matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod"}, {kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, 
kai_matmul_ukernel_f16_qa8d32p_qs4c32p fp16_ukernel_variants[] = {
    {kai_get_m_step_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_n_step_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_mr_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_nr_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_kr_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_sr_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_get_dst_size_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     kai_run_matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     "matmul_clamp_f16_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod"},
    {kai_get_m_step_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_n_step_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_mr_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_nr_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_kr_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_sr_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_dst_offset_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_get_dst_size_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     kai_run_matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     "matmul_clamp_f16_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm"},
};

/// @brief   basic gemm ops
/// @details D = alpha * op(A) * op(B) + beta * C
///          A [b, ..., m, k]
///          B [b, ..., k, n]
///          C [b, ..., m, n]
BufferPtr ArmCpuDevice::gemm_kai_bf16(const GemmParams& params) {
#ifdef GEMM_DEBUG
    auto start = std::chrono::high_resolution_clock::now();
#endif
    params.check();

    std::vector<size_t> Ashape;
    std::vector<size_t> Bshape;
    std::vector<size_t> Dshape;
    size_t              dim;
    size_t              m;
    size_t              k;
    size_t              n;

    Ashape = params.A.shape();
    Bshape = params.B.shape();
    dim    = params.A.dim();

    if (params.transA == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Ashape.end() - 1, Ashape.end() - 2);
    }
    if (params.transB == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Bshape.end() - 1, Bshape.end() - 2);
    }

    m = Ashape[dim - 2];
    k = Ashape[dim - 1];
    n = Bshape[dim - 1];

    auto data_type = params.compute_type == DataType::TYPE_INVALID ? params.A.type() : params.compute_type;

    Dshape = std::vector<size_t>(Ashape.begin(), Ashape.end() - 2);
    Dshape.insert(Dshape.end(), {m, n});

    BufferPtr output;
    if (params.D) {
        output = params.D;
        RUNTIME_ASSERT_OP_ARG((data_type == params.D->type()) && (Dshape == params.D->shape()),
                              "Gemm output D shape and dtype mismatch: expected [%d][%s] but got [%s]",
                              data_type,
                              autil::StringUtil::toString(Dshape).c_str(),
                              params.D->debugString().c_str());
    } else {
        output = allocateBuffer({data_type, Dshape, AllocationType::DEVICE}, {"gemm_output"});
    }
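    // Packing geometry (mr/nr/kr/sr) required by the bf16 MMLA matmul
    // micro-kernel below; the LHS packing routines must use the same values
    // so the packed layout matches what the kernel expects.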
    const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
    const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
    const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
    const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();

    uint8_t* rhs_packed;
    uint8_t* lhs_packed = nullptr;  // initialized so the delete[] below stays safe on the unsupported-type path
    float*   lhs        = (float*)params.A.data();

    // Packing only needs to be performed once if the contents of the bias
    // and RHS matrices are expected to be constant.
    int n_step = nr;
    rhs_packed = (uint8_t*)params.B.data();

    float* dst    = (float*)output->data();
    int    m_step = mr;

    if (params.A.type() == DataType::TYPE_FP32) {
        // LHS in fp32
        const size_t lhs_stride      = k * sizeof(float);
        const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(m, k, mr, kr, sr);
        lhs_packed                   = new uint8_t[lhs_packed_size];

#pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon(m_start, lhs_stride);
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_bf16p8x4_f32_neon(m_start, k, mr, kr, sr);
            int tile_m = (m_start + m_step <= m) ? m_step : m - m_start;

            kai_run_lhs_quant_pack_bf16p8x4_f32_neon(tile_m,
                                                     k,
                                                     mr,
                                                     kr,
                                                     sr,
                                                     0 /* m_idx_start; should stay as 0 */,
                                                     ((uint8_t*)lhs + lhs_offset),  // adjust LHS start position
                                                     lhs_stride,
                                                     (lhs_packed + lhs_packed_offset));
        }
    } else if (params.A.type() == DataType::TYPE_FP16) {
        // LHS in fp16
        const size_t lhs_stride      = k * sizeof(float16_t);
        const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon(m, k, mr, kr, sr);
        lhs_packed                   = new uint8_t[lhs_packed_size];

#pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon(m_start, lhs_stride);
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(m_start, k, mr, kr, sr);
            int tile_m = (m_start + m_step <= m) ? m_step : m - m_start;

            kai_run_lhs_pack_bf16p8x4_f16_neon(tile_m,
                                               k,
                                               mr,
                                               kr,
                                               sr,
                                               0 /* m_idx_start; should stay as 0 */,
                                               ((uint8_t*)lhs + lhs_offset),  // adjust LHS start position
                                               lhs_stride,
                                               (lhs_packed + lhs_packed_offset));
        }
    } else {
        RTP_LLM_LOG_WARNING("Not supported GEMM input type %d", params.A.type());
    }
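    // Matmul phase: tile the output over the N dimension; each OpenMP
    // iteration computes a full [m x tile_n] column slice from the packed
    // LHS/RHS with the clamping bf16 MMLA kernel.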
    if (data_type == DataType::TYPE_FP32) {
        // matmul out fp32
        const size_t dst_stride_row = n * sizeof(float);
        const size_t dst_stride_col = sizeof(float);

#pragma omp parallel for
        for (int n_start = 0; n_start < n; n_start += n_step) {
            size_t lhs_offset;
            size_t rhs_offset;
            size_t dst_offset =
                kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(0, n_start, n * sizeof(float));
            if (params.A.type() == DataType::TYPE_FP32) {
                lhs_offset = kai_get_lhs_packed_offset_lhs_quant_pack_bf16p8x4_f32_neon(0, k, mr, kr, sr);
                rhs_offset = kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(n_start, k, nr, kr);
            } else {
                // For input type FP16 and compute type FP32.
                lhs_offset = kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(0, k, mr, kr, sr);
                rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n_start, k);
            }

            const void* lhs_ptr = (const void*)((const uint8_t*)lhs_packed + lhs_offset);
            const void* rhs_ptr = (const void*)((const uint8_t*)rhs_packed + rhs_offset);
            void*       dst_ptr = (void*)((uint8_t*)dst + dst_offset);

            assert(n % n_step == 0);
            int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;

            kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(m,
                                                                        tile_n,
                                                                        k,               // Dimensions
                                                                        lhs_ptr,         // LHS packed
                                                                        rhs_ptr,         // RHS packed
                                                                        dst_ptr,         // DST
                                                                        dst_stride_row,  // DST stride (row)
                                                                        dst_stride_col,  // DST stride (col)
                                                                        -FLT_MAX,
                                                                        FLT_MAX  // Min and max for the clamp operation
            );
        }
    } else if (data_type == DataType::TYPE_FP16) {
        // matmul out fp16
        const size_t dst_stride_row = n * sizeof(float16_t);
        const size_t dst_stride_col = sizeof(float16_t);

#pragma omp parallel for
        for (int n_start = 0; n_start < n; n_start += n_step) {
            size_t lhs_offset = kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(0, k, mr, kr, sr);
            size_t rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(n_start, k);
            size_t dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla(
                0, n_start, n * sizeof(float16_t));  // row stride in bytes of the fp16 output

            const void* lhs_ptr = (const void*)((const uint8_t*)lhs_packed + lhs_offset);
            const void* rhs_ptr = (const void*)((const uint8_t*)rhs_packed + rhs_offset);
            void*       dst_ptr = (void*)((uint8_t*)dst + dst_offset);

            assert(n % n_step == 0);
            int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;

            kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla(m,
                                                                        tile_n,
                                                                        k,               // Dimensions
                                                                        lhs_ptr,         // LHS packed
                                                                        rhs_ptr,         // RHS packed
                                                                        dst_ptr,         // DST
                                                                        dst_stride_row,  // DST stride (row)
                                                                        dst_stride_col,  // DST stride (col)
                                                                        -FLT_MAX,
                                                                        FLT_MAX  // Min and max for the clamp operation
            );
        }
    } else {
        RTP_LLM_LOG_WARNING("Not supported GEMM output type %d", data_type);
    }
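    // Only the per-call packed LHS buffer is owned here; the packed RHS
    // lives in params.B and is managed by the caller.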
    delete[] lhs_packed;

    /* TODO
    if (m == 1) {
        // gemv
    } else {
        // gemm
    }
    */

#ifdef GEMM_DEBUG
    auto  end          = std::chrono::high_resolution_clock::now();
    float elapsed_time = std::chrono::duration<float>(end - start).count();
    printf("gemm_kai_bf16 m,n,k %zu %zu %zu %.3f\n", m, n, k, elapsed_time * 1000);
#endif
    return output;
}

BufferPtr ArmCpuDevice::gemm_kai_a8w4(const GemmParams& params) {
#ifdef GEMM_DEBUG
    auto start = std::chrono::high_resolution_clock::now();
#endif
    params.check();

    std::vector<size_t> Ashape;
    std::vector<size_t> Bshape;
    std::vector<size_t> Dshape;
    size_t              dim;
    size_t              m;
    size_t              k;
    size_t              n;

    Ashape = params.A.shape();
    Bshape = params.B.shape();
    dim    = params.A.dim();

    if (params.transA == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Ashape.end() - 1, Ashape.end() - 2);
    }
    if (params.transB == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Bshape.end() - 1, Bshape.end() - 2);
    }

    m = Ashape[dim - 2];
    k = Ashape[dim - 1];
    n = Bshape[dim - 1];

    auto data_type = params.compute_type == DataType::TYPE_INVALID ? params.A.type() : params.compute_type;

    Dshape = std::vector<size_t>(Ashape.begin(), Ashape.end() - 2);
    Dshape.insert(Dshape.end(), {m, n});

    BufferPtr output;
    if (params.D) {
        output = params.D;
        RUNTIME_ASSERT_OP_ARG((data_type == params.D->type()) && (Dshape == params.D->shape()),
                              "Gemm output D shape and dtype mismatch: expected [%d][%s] but got [%s]",
                              data_type,
                              autil::StringUtil::toString(Dshape).c_str(),
                              params.D->debugString().c_str());
    } else {
        output = allocateBuffer({data_type, Dshape, AllocationType::DEVICE}, {"gemm_output"});
    }

    // Variant 0 is the dotprod GEMV kernel, variant 1 the i8mm GEMM kernel.
    // For FP16 input/output, currently only the GEMV (m == 1) path is supported.
    size_t idx_variant = 0;
    if (m == 1) {
        idx_variant = 0;
    } else {
        idx_variant = 1;
    }

    // Get the packing parameters; zero-initialized so an unsupported dtype
    // does not leave them undefined.
    size_t mr = 0;
    size_t kr = 0;
    size_t sr = 0;
    if (data_type == DataType::TYPE_FP32) {
        mr = fp32_ukernel_variants[idx_variant].ukernel.get_mr();
        kr = fp32_ukernel_variants[idx_variant].ukernel.get_kr();
        sr = fp32_ukernel_variants[idx_variant].ukernel.get_sr();
    } else if (data_type == DataType::TYPE_FP16) {
        mr = fp16_ukernel_variants[idx_variant].ukernel.get_mr();
        kr = fp16_ukernel_variants[idx_variant].ukernel.get_kr();
        sr = fp16_ukernel_variants[idx_variant].ukernel.get_sr();
    } else {
        RTP_LLM_LOG_WARNING("Not supported GEMM output type %d", data_type);
    }

    const size_t lhs_stride     = k * sizeof(float);
    const size_t dst_stride_col = sizeof(float);
    const size_t bl             = 32;  // quantization block length along K

    const size_t lhs_packed_size       = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32(m, k, bl, mr, kr, sr);
    uint8_t*     lhs_packed_mtx_qs8d32 = new uint8_t[lhs_packed_size];
    uint8_t*     rhs_packed_mtx_qs4c32 = (uint8_t*)params.B.data();
    float*       lhs                   = (float*)params.A.data();

    int n_step = 32;  // 32 is the best for performance
    int m_step = mr;

    // LHS packing: quantize the activations to per-block int8 on the fly.
    if (params.A.type() == DataType::TYPE_FP32) {
#pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(m_start, lhs_stride);
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32(m_start, k, bl, mr, kr, sr);
            int tile_m = (m_start + m_step <= m) ? m_step : m - m_start;

            kai_run_lhs_quant_pack_qsi8d32p_f32(tile_m,
                                                k,
                                                bl,
                                                mr,
                                                kr,
                                                sr,
                                                0,
                                                (const float*)((uint8_t*)lhs + lhs_offset),
                                                lhs_stride,
                                                ((uint8_t*)lhs_packed_mtx_qs8d32 + lhs_packed_offset));
        }
    } else if (params.A.type() == DataType::TYPE_FP16) {
#pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f16(m_start, k * sizeof(float16_t));
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f16(m_start, k, bl, mr, kr, sr);
            int tile_m = (m_start + m_step <= m) ? m_step : m - m_start;

            kai_run_lhs_quant_pack_qsi8d32p_f16(tile_m,
                                                k,
                                                bl,
                                                mr,
                                                kr,
                                                sr,
                                                0,
                                                (const float16_t*)((uint8_t*)lhs + lhs_offset),
                                                k * sizeof(float16_t),
                                                ((uint8_t*)lhs_packed_mtx_qs8d32 + lhs_packed_offset));
        }
    } else {
        RTP_LLM_LOG_WARNING("Not supported GEMM A type %d", params.A.type());
    }
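    // Matmul phase: per-32-block int8 activations x per-32-block int4
    // weights, tiled over N; results are dequantized and clamped into
    // fp32 or fp16 output.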
    if (data_type == DataType::TYPE_FP32) {
#pragma omp parallel for
        for (int n_start = 0; n_start < n; n_start += n_step) {
            const size_t dst_stride = n * sizeof(float);
            const size_t lhs_offset = fp32_ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k, bl);
            const size_t rhs_offset = fp32_ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(n_start, k, bl);
            const size_t dst_offset = fp32_ukernel_variants[idx_variant].ukernel.get_dst_offset(0, n_start, dst_stride);

            const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qs8d32 + lhs_offset);
            const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset);
            float*      dst_ptr = (float*)((uint8_t*)output->data() + dst_offset);

            int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;

            fp32_ukernel_variants[idx_variant].ukernel.run_matmul(m,
                                                                  tile_n,
                                                                  k,
                                                                  bl,              // Dimensions
                                                                  lhs_ptr,         // LHS packed
                                                                  rhs_ptr,         // RHS packed
                                                                  dst_ptr,         // DST
                                                                  dst_stride,      // DST stride (row)
                                                                  dst_stride_col,  // DST stride (col)
                                                                  -FLT_MAX,
                                                                  FLT_MAX  // Min and max for the clamp operation
            );
        }
    } else if (data_type == DataType::TYPE_FP16) {
#pragma omp parallel for
        for (int n_start = 0; n_start < n; n_start += n_step) {
            const size_t dst_stride = n * sizeof(float16_t);
            const size_t lhs_offset = fp16_ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k, bl);
            const size_t rhs_offset = fp16_ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(n_start, k, bl);
            const size_t dst_offset = fp16_ukernel_variants[idx_variant].ukernel.get_dst_offset(0, n_start, dst_stride);

            const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qs8d32 + lhs_offset);
            const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset);
            float16_t*  dst_ptr = (float16_t*)((uint8_t*)output->data() + dst_offset);

            int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;

            fp16_ukernel_variants[idx_variant].ukernel.run_matmul(m,
                                                                  tile_n,
                                                                  k,
                                                                  bl,                 // Dimensions
                                                                  lhs_ptr,            // LHS packed
                                                                  rhs_ptr,            // RHS packed
                                                                  dst_ptr,            // DST
                                                                  dst_stride,         // DST stride (row)
                                                                  sizeof(float16_t),  // DST stride (col)
                                                                  -HALF_FLT_MAX,
                                                                  HALF_FLT_MAX  // Min and max for the clamp operation
            );
        }
    }

    delete[] lhs_packed_mtx_qs8d32;

#ifdef GEMM_DEBUG
    auto  end          = std::chrono::high_resolution_clock::now();
    float elapsed_time = std::chrono::duration<float>(end - start).count();
    printf("gemm_kai_a8w4 m,n,k %zu %zu %zu %.3f\n", m, n, k, elapsed_time * 1000);
#endif
    return output;
}

}  // namespace rtp_llm