BufferPtr ArmCpuDevice::gemm_kai_a8w4()

in maga_transformer/cpp/devices/arm_impl/ArmGemmKaiOp.cc [299:475]
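
gemm_kai_a8w4 runs an A8W4 GEMM on Arm CPUs: the FP32/FP16 activation matrix A is quantized on the fly to int8 with one scale per 32-element block (the qsi8d32 format) and multiplied against weights that were pre-packed into the 4-bit block-quantized qs4c32 layout, using Arm KleidiAI micro-kernels (the kai_* entry points). A GEMV-specialized kernel variant is used when m == 1, a general GEMM variant otherwise.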


BufferPtr ArmCpuDevice::gemm_kai_a8w4(const GemmParams& params) {
#ifdef GEMM_DEBUG
    auto start = std::chrono::high_resolution_clock::now();
#endif
    params.check();

    std::vector<size_t> Ashape = params.A.shape();
    std::vector<size_t> Bshape = params.B.shape();

    const size_t dim = params.A.dim();

    if (params.transA == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Ashape.end() - 1, Ashape.end() - 2);
    }

    if (params.transB == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Bshape.end() - 1, Bshape.end() - 2);
    }

    const size_t m = Ashape[dim - 2];
    const size_t k = Ashape[dim - 1];
    // Take n from the last dimension of B, whose rank may differ from A's.
    const size_t n = Bshape[Bshape.size() - 1];

    auto data_type = params.compute_type == DataType::TYPE_INVALID ? params.A.type() : params.compute_type;

    std::vector<size_t> Dshape(Ashape.begin(), Ashape.end() - 2);
    Dshape.insert(Dshape.end(), {m, n});

    BufferPtr output;
    if (params.D) {
        output = params.D;
        RUNTIME_ASSERT_OP_ARG((data_type == params.D->type()) && (Dshape == params.D->shape()),
                              "Gemm output D shape and dtype mismatch: expected [%d][%s] but got [%s]",
                              data_type,
                              autil::StringUtil::toString(Dshape).c_str(),
                              params.D->debugString().c_str());
    } else {
        output = allocateBuffer({data_type, Dshape, AllocationType::DEVICE}, {"gemm_output"});
    }

    // Select the micro-kernel variant: index 0 is the GEMV kernel (m == 1),
    // index 1 the general GEMM kernel.
    const size_t idx_variant = (m == 1) ? 0 : 1;

    // Query the packing parameters (mr/kr/sr) of the selected micro-kernel; the LHS
    // must be packed with the same blocking the kernel expects.
    size_t mr = 0;
    size_t kr = 0;
    size_t sr = 0;
    if (data_type == DataType::TYPE_FP32) {
        mr = fp32_ukernel_variants[idx_variant].ukernel.get_mr();
        kr = fp32_ukernel_variants[idx_variant].ukernel.get_kr();
        sr = fp32_ukernel_variants[idx_variant].ukernel.get_sr();
    } else if (data_type == DataType::TYPE_FP16) {
        mr = fp16_ukernel_variants[idx_variant].ukernel.get_mr();
        kr = fp16_ukernel_variants[idx_variant].ukernel.get_kr();
        sr = fp16_ukernel_variants[idx_variant].ukernel.get_sr();
    } else {
        // Fail fast instead of continuing with uninitialized packing parameters.
        RUNTIME_ASSERT_OP_ARG(false, "Not supported GEMM output type %d", data_type);
    }

    const size_t lhs_stride = k * sizeof(float);  // row stride of the unpacked FP32 LHS
    const size_t dst_stride_col = sizeof(float);

    // Quantization block length along k: one scale per 32 elements, shared by the
    // qsi8d32 LHS packing and the pre-packed qs4c32 RHS.
    const size_t bl = 32;

    const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32(m, k, bl, mr, kr, sr);
    std::vector<uint8_t> lhs_packed_buf(lhs_packed_size);
    uint8_t* lhs_packed_mtx_qs8d32 = lhs_packed_buf.data();

    const uint8_t* rhs_packed_mtx_qs4c32 = (const uint8_t*)params.B.data();
    const uint8_t* lhs = (const uint8_t*)params.A.data();

    const int n_step = 32;  // 32 gives the best measured performance
    const int m_step = (int)mr;
    // LHS packing
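    // Each iteration quantizes and packs a tile of up to m_step (= mr) rows; the kai_get_*
    // offset helpers place every tile at a disjoint location, so the loop parallelizes safely.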
    if (params.A.type() == DataType::TYPE_FP32) {
        #pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < (int)m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(m_start, lhs_stride);
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32(m_start, k, bl, mr, kr, sr);
            const int tile_m = (m_start + m_step <= (int)m) ? m_step : (int)m - m_start;

            kai_run_lhs_quant_pack_qsi8d32p_f32(
                tile_m, k, bl, mr, kr, sr, 0,
                (const float*)(lhs + lhs_offset),
                lhs_stride,
                lhs_packed_mtx_qs8d32 + lhs_packed_offset);
        }
    } else if (params.A.type() == DataType::TYPE_FP16) {
        #pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < (int)m; m_start += m_step) {
            const size_t lhs_stride_f16 = k * sizeof(float16_t);
            const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f16(m_start, lhs_stride_f16);
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f16(m_start, k, bl, mr, kr, sr);
            const int tile_m = (m_start + m_step <= (int)m) ? m_step : (int)m - m_start;

            kai_run_lhs_quant_pack_qsi8d32p_f16(
                tile_m, k, bl, mr, kr, sr, 0,
                (const float16_t*)(lhs + lhs_offset),
                lhs_stride_f16,
                lhs_packed_mtx_qs8d32 + lhs_packed_offset);
        }
    } else {
        // Fail fast: the packed LHS would otherwise be left uninitialized.
        RUNTIME_ASSERT_OP_ARG(false, "Not supported GEMM A type %d", params.A.type());
    }

    // Matmul
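    // The output columns are split into n_step-wide tiles. Every tile reads the whole
    // packed LHS (lhs_offset is taken at row 0) but disjoint slices of the packed RHS
    // and of the destination, so iterations are independent.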
    if (data_type == DataType::TYPE_FP32) {
        #pragma omp parallel for
        for (int n_start = 0; n_start < (int)n; n_start += n_step) {
            const size_t dst_stride = n * sizeof(float);
            const size_t lhs_offset = fp32_ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k, bl);
            const size_t rhs_offset = fp32_ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(n_start, k, bl);
            const size_t dst_offset = fp32_ukernel_variants[idx_variant].ukernel.get_dst_offset(0, n_start, dst_stride);

            const void* lhs_ptr = lhs_packed_mtx_qs8d32 + lhs_offset;
            const void* rhs_ptr = rhs_packed_mtx_qs4c32 + rhs_offset;
            float* dst_ptr = (float*)((uint8_t*)output->data() + dst_offset);

            const int tile_n = (n_start + n_step <= (int)n) ? n_step : (int)n - n_start;

            fp32_ukernel_variants[idx_variant].ukernel.run_matmul(
                m, tile_n, k, bl,  // dimensions
                lhs_ptr,           // packed LHS
                rhs_ptr,           // packed RHS
                dst_ptr,           // destination tile
                dst_stride,        // DST row stride (bytes)
                dst_stride_col,    // DST column stride (bytes)
                -FLT_MAX, FLT_MAX  // clamp bounds (i.e. no clamping)
            );
        }
    } else if (data_type == DataType::TYPE_FP16) {
        #pragma omp parallel for
        for (int n_start = 0; n_start < (int)n; n_start += n_step) {
            const size_t dst_stride = n * sizeof(float16_t);
            const size_t lhs_offset = fp16_ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k, bl);
            const size_t rhs_offset = fp16_ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(n_start, k, bl);
            const size_t dst_offset = fp16_ukernel_variants[idx_variant].ukernel.get_dst_offset(0, n_start, dst_stride);

            const void* lhs_ptr = lhs_packed_mtx_qs8d32 + lhs_offset;
            const void* rhs_ptr = rhs_packed_mtx_qs4c32 + rhs_offset;
            float16_t* dst_ptr = (float16_t*)((uint8_t*)output->data() + dst_offset);

            const int tile_n = (n_start + n_step <= (int)n) ? n_step : (int)n - n_start;

            fp16_ukernel_variants[idx_variant].ukernel.run_matmul(
                m, tile_n, k, bl,            // dimensions
                lhs_ptr,                     // packed LHS
                rhs_ptr,                     // packed RHS
                dst_ptr,                     // destination tile
                dst_stride,                  // DST row stride (bytes)
                sizeof(float16_t),           // DST column stride (bytes)
                -HALF_FLT_MAX, HALF_FLT_MAX  // clamp bounds (i.e. no clamping)
            );
        }
    }


#ifdef GEMM_DEBUG
    auto end = std::chrono::high_resolution_clock::now();
    float elapsed = std::chrono::duration<float>(end - start).count();
    printf("gemm_kai_a8w4 m,n,k %zu %zu %zu %.3f ms\n", m, n, k, elapsed * 1000);
#endif
    return output;
}
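
For intuition on the qsi8d32 scheme the LHS packing applies, the sketch below quantizes one row to int8 with a separate FP32 scale per 32-element block. It is illustrative only: the real kai_run_lhs_quant_pack_qsi8d32p_* kernels additionally interleave values and scales into the mr/kr/sr-blocked layout the matmul kernel consumes, and their exact rounding may differ.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone sketch (not the KleidiAI layout): symmetric per-block int8
// quantization with block length bl = 32, as in the qsi8d32 LHS format.
int main() {
    const size_t k = 64, bl = 32;
    std::vector<float> row(k);
    for (size_t i = 0; i < k; ++i) row[i] = 0.01f * (float)i - 0.3f;

    std::vector<int8_t> q(k);
    std::vector<float> scales(k / bl);
    for (size_t b = 0; b < k / bl; ++b) {
        // One FP32 scale per 32-element block, sized so the block's max maps to 127.
        float amax = 0.f;
        for (size_t i = 0; i < bl; ++i) amax = std::fmax(amax, std::fabs(row[b * bl + i]));
        const float scale = amax / 127.f;
        scales[b] = scale;
        for (size_t i = 0; i < bl; ++i)
            q[b * bl + i] = (int8_t)std::lround(row[b * bl + i] / (scale > 0.f ? scale : 1.f));
    }
    std::printf("block scales: %f %f\n", scales[0], scales[1]);
}

The qs4c32 weight side follows the same idea with 4-bit values (packed two per byte) and, likewise, one scale per 32-element block, which is why both packings share bl = 32.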