BufferPtr ArmCpuDevice::gemm_kai_bf16()

in maga_transformer/cpp/devices/arm_impl/ArmGemmKaiOp.cc [103:297]
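
Bf16 GEMM on Arm using the KleidiAI matmul_clamp_*_bf16p8x4_bf16p12x4b_8x12_neon_mmla micro-kernels: the LHS (params.A, fp32 or fp16) is packed to bf16 tiles on the fly, the RHS (params.B) is expected to be pre-packed, and the result is written as fp32 or fp16 depending on the compute type.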


BufferPtr ArmCpuDevice::gemm_kai_bf16(const GemmParams& params) {
#ifdef GEMM_DEBUG
    auto start = std::chrono::high_resolution_clock::now();
#endif
    params.check();

    std::vector<size_t> Ashape;
    std::vector<size_t> Bshape;
    std::vector<size_t> Dshape;

    size_t dim;
    size_t m;
    size_t k;
    size_t n;

    Ashape = params.A.shape();
    Bshape = params.B.shape();

    dim = params.A.dim();

    if (params.transA == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Ashape.end() - 1, Ashape.end() - 2);
    }

    if (params.transB == TransposeOperation::TRANSPOSE) {
        std::iter_swap(Bshape.end() - 1, Bshape.end() - 2);
    }

    m = Ashape[dim - 2];
    k = Ashape[dim - 1];
    n = Bshape[dim - 1];
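    // m and k come from the trailing two dims of A (after any transpose), n from the
    // trailing dim of B; the output keeps A's leading dims and ends in {m, n}.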

    auto data_type = params.compute_type == DataType::TYPE_INVALID ? params.A.type() : params.compute_type;

    Dshape = std::vector<size_t>(Ashape.begin(), Ashape.end() - 2);
    Dshape.insert(Dshape.end(), {m, n});

    BufferPtr output;
    if (params.D) {
        output = params.D;
        RUNTIME_ASSERT_OP_ARG((data_type == params.D->type()) && (Dshape == params.D->shape()),
                              "Gemm output D shape and dtype mismatch: expected [%d][%s] but got [%s]",
                              data_type,
                              autil::StringUtil::toString(Dshape).c_str(),
                              params.D->debugString().c_str());
    } else {
        output = allocateBuffer({data_type, Dshape, AllocationType::DEVICE}, {"gemm_output"});
    }

    const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
    const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
    const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
    const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
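    // mr/nr are the M/N tile sizes handled per micro-kernel call; kr/sr describe how the
    // K dimension is laid out in the packed LHS/RHS buffers.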

    uint8_t* rhs_packed = nullptr;
    uint8_t* lhs_packed = nullptr;

    float* lhs = (float*)params.A.data();

    // RHS (weights) and bias packing only needs to be performed once when their contents
    // are constant; params.B is expected to already contain the packed RHS data.
    int n_step = nr;
    rhs_packed = (uint8_t*)params.B.data();
    float* dst = (float*)output->data();

    int m_step = mr;
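    // m_step/n_step equal the micro-kernel tile sizes; each loop iteration below packs or
    // multiplies one tile, with the final tile possibly smaller than a full step.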

    if (params.A.type() == DataType::TYPE_FP32) {
        // lhs in fp32
        const size_t lhs_stride = k * sizeof(float);
        const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(m, k, mr, kr, sr);
        lhs_packed = new uint8_t[lhs_packed_size];

        #pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon(m_start, lhs_stride);
            const size_t lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_bf16p8x4_f32_neon(m_start, k, mr, kr, sr);
            int tile_m = (m_start + m_step <= m) ? m_step : m - m_start;

            kai_run_lhs_quant_pack_bf16p8x4_f32_neon(
                tile_m, k, mr, kr, sr,
                0 /* m_idx_start; should stay as 0 */,
                ((uint8_t*)lhs + lhs_offset), // adjust Lhs start position
                lhs_stride,
                (lhs_packed + lhs_packed_offset));
        }
    } else if (params.A.type() == DataType::TYPE_FP16) {
        // lhs in fp16
        const size_t lhs_stride = k * sizeof(float16_t);
        const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon(m, k, mr, kr, sr);
        lhs_packed = new uint8_t[lhs_packed_size];

        #pragma omp parallel for if (m > 1)
        for (int m_start = 0; m_start < m; m_start += m_step) {
            const size_t lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon(m_start, lhs_stride);
            const size_t lhs_packed_offset = kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(m_start, k, mr, kr, sr);
            int tile_m = (m_start + m_step <= m) ? m_step : m - m_start;

            kai_run_lhs_pack_bf16p8x4_f16_neon(
                tile_m, k, mr, kr, sr,
                0 /* m_idx_start; should stay as 0 */,
                ((uint8_t*)lhs + lhs_offset), // adjust Lhs start position
                lhs_stride,
                (lhs_packed + lhs_packed_offset));
        }
    } else {
        RTP_LLM_LOG_WARNING("Not supported GEMM input type %d", params.A.type());
    }
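    // LHS is now packed; the matmul below is parallelized over n_step-wide column tiles
    // and writes fp32 or fp16 output depending on data_type.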

    if (data_type == DataType::TYPE_FP32) {
        // matmul out fp32
        const size_t dst_stride_row = n * sizeof(float);
        const size_t dst_stride_col = sizeof(float);

        #pragma omp parallel for
        for (int n_start = 0; n_start < n; n_start += n_step) {
            size_t lhs_offset;
            size_t rhs_offset;
            size_t dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(0, n_start, n * sizeof(float));
            if (params.A.type() == DataType::TYPE_FP32) {
                lhs_offset = kai_get_lhs_packed_offset_lhs_quant_pack_bf16p8x4_f32_neon(0, k, mr, kr, sr);
                rhs_offset = kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(n_start, k, nr, kr);
            } else { // For input type FP16 and compute type FP32.
                lhs_offset = kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(0, k, mr, kr, sr);
                rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n_start, k);
            }

            const void* lhs_ptr = (const void*)((const uint8_t*)lhs_packed + lhs_offset);
            const void* rhs_ptr = (const void*)((const uint8_t*)rhs_packed + rhs_offset);
            void* dst_ptr = (void*)((uint8_t*)dst + dst_offset);

            assert(n % n_step == 0);

            int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
            kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(
                m, tile_n, k,       // Dimensions
                lhs_ptr,            // LHS packed
                rhs_ptr,            // RHS packed
                dst_ptr,            // DST
                dst_stride_row,     // DST stride (row)
                dst_stride_col,     // DST stride (col)
                -FLT_MAX, FLT_MAX   // Min and max for the clamp operation
            );
        }
    } else if (data_type == DataType::TYPE_FP16) {
        // matmul out fp16

        const size_t dst_stride_row = n * sizeof(float16_t);
        const size_t dst_stride_col = sizeof(float16_t);

        #pragma omp parallel for
        for (int n_start = 0; n_start < n; n_start += n_step) {
            size_t lhs_offset = kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(0, k, mr, kr, sr);
            size_t rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(n_start, k);
            size_t dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla(0, n_start, n * sizeof(float16_t));

            const void* lhs_ptr = (const void*)((const uint8_t*)lhs_packed + lhs_offset);
            const void* rhs_ptr = (const void*)((const uint8_t*)rhs_packed + rhs_offset);
            void* dst_ptr = (void*)((uint8_t*)dst + dst_offset);

            assert(n % n_step == 0);

            int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
            kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla(
                m, tile_n, k,       // Dimensions
                lhs_ptr,            // LHS packed
                rhs_ptr,            // RHS packed
                dst_ptr,            // DST
                dst_stride_row,     // DST stride (row)
                dst_stride_col,     // DST stride (col)
                -FLT_MAX, FLT_MAX   // Min and max for the clamp operation
            );
        }
    } else {
        RTP_LLM_LOG_WARNING("Not supported GEMM output type %d", data_type);
    }

    delete[] lhs_packed;

    /* TODO 
    if (m == 1) {
        // gemv
    } else {
        // gemm
    }
    */

#ifdef GEMM_DEBUG
    auto end = std::chrono::high_resolution_clock::now();
    float elapsed_ms = std::chrono::duration<float, std::milli>(end - start).count();
    printf("gemm_kai_bf16 m,n,k %zu %zu %zu %.3f\n", m, n, k, elapsed_ms);
#endif
    return output;
}
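
Both the LHS-packing and matmul loops above share the same tile-with-remainder partitioning (the tile_m / tile_n computation). A minimal standalone sketch of that pattern; the for_each_tile helper and the sizes are illustrative, not part of the source:

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Walk a dimension of size `total` in `step`-sized tiles; the last tile may be
// smaller when `total` is not a multiple of `step`, matching the tile_m/tile_n
// computation in gemm_kai_bf16.
template <typename Body>
static void for_each_tile(size_t total, size_t step, Body&& body) {
    for (size_t start = 0; start < total; start += step) {
        body(start, std::min(step, total - start));
    }
}

int main() {
    // e.g. n = 50 with n_step = 12 -> tiles of 12, 12, 12, 12 and a final tile of 2
    for_each_tile(50, 12, [](size_t start, size_t len) {
        std::printf("tile at %zu, size %zu\n", start, len);
    });
    return 0;
}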