ConstBufferPtr prepareGemmOptForGPTQInt4()

in maga_transformer/cpp/devices/arm_impl/gemm_opt/ArmGemmPacking.cc [458:1021]


// Expands GPTQ-packed signed int4 weights into a dense row-major [k, n] matrix.
//
// `qweight` holds k rows of n / 2 int8 bytes; byte col / 2 of a row carries column col in its
// low nibble and column col + 1 in its high nibble, each a 4-bit two's-complement value.
// `scales` is read as rows of n entries, one row per `group_size` consecutive weight rows.
// Each output element is scale * nibble (computed as in `ScaleT * int`), converted to OutT.
// `n` is expected to be even. Rows are independent and are dequantized in parallel (OpenMP).
template <typename ScaleT, typename OutT>
static void dequantizeGptqInt4(const int8_t* qweight,
                               const ScaleT* scales,
                               size_t        k,
                               size_t        n,
                               size_t        group_size,
                               OutT*         out) {
#pragma omp parallel for
    for (int64_t row = 0; row < static_cast<int64_t>(k); ++row) {
        const size_t  r     = static_cast<size_t>(row);
        const int8_t* q_row = qweight + r * (n / 2);
        const ScaleT* s_row = scales + (r / group_size) * n;
        OutT*         o_row = out + r * n;
        for (size_t col = 0; col < n; col += 2) {
            const int8_t packed = q_row[col / 2];
            int          lo     = packed & 0x0F;
            int          hi     = (packed >> 4) & 0x0F;
            // Sign-extend each 4-bit nibble.
            if (lo & 0x08) {
                lo -= 16;
            }
            if (hi & 0x08) {
                hi -= 16;
            }
            o_row[col]     = s_row[col] * lo;
            o_row[col + 1] = s_row[col + 1] * hi;
        }
    }
}

// Repacks a GPTQ int4 weight (int8 storage, two values per byte) with fp16 scales into the
// layout expected by the ARM GEMM backend. Any other kernel/scale dtype combination is
// returned unchanged. `key` is currently unused.
//
// kernel: [..., k, n / 2] int8 -- the logical output width is n (two int4 values per byte).
// scales: fp16, read as rows of n entries, one row per 128 rows of k.
//
// With GPTQ_COMPUTE_AS_DI_BF16 the weight is dequantized to fp16 and packed to BF16 via
// GemmKernel::gemm_pack_weight_FP16toBF16_arm. Otherwise it is dequantized to fp32,
// transposed, re-quantized with quant_qs4c32_f32 (block length 32) and packed with the
// nr/kr/sr parameters of the KleidiAI qsi8d32p1x8_qsi4c32p4x8_1x4x32 neon_dotprod matmul.
ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr scales, const std::string& key) {
    if (kernel->type() != DataType::TYPE_INT8 || scales->type() != DataType::TYPE_FP16) {
        return kernel;
    }

    const std::vector<size_t> Bshape         = kernel->shape();
    const auto                dim            = kernel->dim();
    const size_t              k              = Bshape[dim - 2];
    const size_t              n              = Bshape[dim - 1] * 2;  // two int4 values per int8 byte
    const int8_t*             qweight        = static_cast<const int8_t*>(kernel->data());
    const __fp16*             qscales        = static_cast<const __fp16*>(scales->data());
    constexpr size_t          kGptqGroupSize = 128;

    // Leading dimensions of the kernel shape are carried over to every output shape.
    const std::vector<size_t> lead_shape(Bshape.begin(), Bshape.end() - 2);

#if GPTQ_COMPUTE_AS_DI_BF16
    std::unique_ptr<__fp16[]> unpacked_weight(new __fp16[k * n]);
    dequantizeGptqInt4(qweight, qscales, k, n, kGptqGroupSize, unpacked_weight.get());

    std::vector<size_t> weight_workspace_shape = lead_shape;
    weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n});

    // The packer pads K up to a multiple of 8 and writes up to n * weight_k_pack bf16 values,
    // which exceeds the k * n elements implied by the shape whenever k % 8 != 0. Allocate the
    // larger of the two and zero all of it so the padding lanes the packer skips read as 0.
    const size_t weight_k_pack = (k + 7) / 8 * 8;
    const size_t element_num =
        std::accumulate(Bshape.begin(), Bshape.end(), size_t{1}, std::multiplies<size_t>()) * 2;
    const size_t packed_num = n * weight_k_pack;
    const size_t alloc_num  = packed_num > element_num ? packed_num : element_num;
    void*        data       = calloc(alloc_num, sizeof(hie::bfloat16));

    ConstBufferPtr weight_workspace =
        BufferPtr(new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_BF16, weight_workspace_shape, data));

    GemmKernel gemm_kernel;
    gemm_kernel.gemm_pack_weight_FP16toBF16_arm(
        n, k, weight_k_pack, unpacked_weight.get(), static_cast<hie::bfloat16*>(data));
    return weight_workspace;
#else
    // Dequantize to fp32, then transpose for the qs4c32 quantizer below.
    std::vector<float> unpacked_weight(k * n);
    dequantizeGptqInt4(qweight, qscales, k, n, kGptqGroupSize, unpacked_weight.data());

    std::vector<size_t> input_shape = lead_shape;
    input_shape.insert(input_shape.end(), {k, n});
    BufferPtr input = BufferPtr(
        new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_FP32, input_shape, unpacked_weight.data()));

    auto transposedWeight = transposeWeight(input);

    const size_t bl                         = 32;
    const size_t num_blocks                 = k / bl;
    const size_t num_bytes_per_block_qs4c32 = (bl / 2) + sizeof(int16_t);
    const size_t rhs_native_size_qs4c32     = n * num_blocks * num_bytes_per_block_qs4c32;

    const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod();

    const size_t rhs_packed_size =
        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, nr, kr, bl);

    // The packed buffer is handed to the returned Buffer and is not freed here.
    uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size];

    std::vector<size_t> weight_workspace_shape = lead_shape;
    weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n / 2});
    BufferPtr output = BufferPtr(
        new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_UINT8, weight_workspace_shape, rhs_packed_mtx_qs4c32));

    // Intermediate (unpacked) qs4c32 data; only needed until packing finishes.
    std::vector<uint8_t> rhs_native_mtx_qs4c32(rhs_native_size_qs4c32);
    quant_qs4c32_f32(n, k, bl, (const float*)transposedWeight->data(), rhs_native_mtx_qs4c32.data());

    struct kai_rhs_pack_qs4cxs1s0_param kai_rhs_params;
    kai_rhs_params.lhs_zero_point = 1;
    kai_rhs_params.rhs_zero_point = 8;

    // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant.
    // Rows of N are packed in independent tiles of n_step so the work can be split across threads.
    const int    n_step     = 32;
    const size_t rhs_stride = kai_rhs_stride(k, bl);

#pragma omp parallel for
    for (int n_start = 0; n_start < n; n_start += n_step) {
        const size_t rhs_offset =
            kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n_start, rhs_stride);
        const size_t packed_offset =
            kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n_start, k, nr, kr, bl);

        const int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
        kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
            1, tile_n, k,                                 // Dimensions
            nr, kr, sr,                                   // Packing arguments
            bl,                                           // Block length
            rhs_native_mtx_qs4c32.data() + rhs_offset,    // RHS
            nullptr,                                      // Bias
            rhs_packed_mtx_qs4c32 + packed_offset,        // RHS packed
            0, &kai_rhs_params);
    }

    return output;
#endif
}

void GemmKernel::pack_input_arm(int M, int N, int K, int lda, int K_pack, float* a_fp32, hie::bfloat16* a_bf16) {
    pack_input_impl_parallel_simd(M, N, K, lda, K_pack, a_fp32, a_bf16);
    return;
}

// Packs a row-major fp32 weight matrix B (K rows, row stride N) into the BF16 layout used by
// the ARM BF16 GEMM.
//
// Layout:
//  - Columns are taken in pairs (n, n + 1). Pair n / 2 starts at b_bf16 + n * K_pack.
//  - Every 4 consecutive k produce 8 bf16 values: 4 from column n, followed by the same 4 k
//    from column n + 1.
//
// Slots that are never written:
//  - k >= K (the padding up to K_pack);
//  - the partner slot of a trailing odd column.
// The caller must pre-zero b_bf16 if those slots need to be defined.
//
// K is split into tiles of 1024 rows that are processed via parallel_for.
void GemmKernel::gemm_pack_weight_FP32toBF16_arm(int N, int K, int K_pack, const float* b_fp32, hie::bfloat16* b_bf16) {
    const int k_tile   = 1024;  // empirical var: 1024, 5120
    const int k_thread = std::ceil(K_pack * 1.0 / k_tile);

    parallel_for(k_thread, [&](int k) {
        // Offsets are computed in size_t: k * k_tile * N and n * K_pack can exceed INT_MAX for
        // large weights, and signed int overflow is undefined behavior.
        const size_t k_begin = static_cast<size_t>(k) * k_tile;
        const int    kk_max  = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
        for (int n = 0; n < N; n += 2) {
            const float*   b_fp32_ptr1 = b_fp32 + k_begin * N + n + 0;
            const float*   b_fp32_ptr2 = b_fp32 + k_begin * N + n + 1;
            hie::bfloat16* b_bf16_ptr  = b_bf16 + static_cast<size_t>(n) * K_pack + k_begin * 2;  // [n, k*k_tile*2]
            for (int kk = k * k_tile; kk < kk_max; kk += 4) {
                for (int i = 0; i < 4 && (kk + i < kk_max); i++) {
                    b_bf16_ptr[i] = b_fp32_ptr1[i * N];
                    if (n != (N - 1)) {
                        b_bf16_ptr[i + 4] = b_fp32_ptr2[i * N];
                    }
                }
                b_bf16_ptr += 8;
                b_fp32_ptr1 += 4 * N;
                b_fp32_ptr2 += 4 * N;
            }
        }
    });

#ifdef PACK_DEBUG
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < K; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            printf("%f ", b_fp32[j * N + i]);
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");

    auto N_aligned = N / 2 + (N % 2);
    for (int i = 0; i < N_aligned; i++) {
        for (int j = 0; j < K_pack * 2; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            std::cout << std::setiosflags(std::ios::fixed) << std::setprecision(6) << b_bf16[i * K_pack * 2 + j] << " ";
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");
#endif
}


// Packs a row-major fp16 weight matrix B (K rows, row stride N) into the BF16 layout used by
// the ARM BF16 GEMM.
//
// Layout:
//  - Columns are taken in pairs (n, n + 1). Pair n / 2 starts at b_bf16 + n * K_pack.
//  - Every 4 consecutive k produce 8 bf16 values: 4 from column n, followed by the same 4 k
//    from column n + 1.
//
// Slots that are never written:
//  - k >= K (the padding up to K_pack);
//  - the partner slot of a trailing odd column.
// The caller must pre-zero b_bf16 if those slots need to be defined.
//
// K is split into tiles of 1024 rows that are processed via parallel_for.
void GemmKernel::gemm_pack_weight_FP16toBF16_arm(int N, int K, int K_pack, const float16_t* b_fp16, hie::bfloat16* b_bf16) {
    const int k_tile   = 1024;  // empirical var: 1024, 5120
    const int k_thread = std::ceil(K_pack * 1.0 / k_tile);

    parallel_for(k_thread, [&](int k) {
        // Offsets are computed in size_t: k * k_tile * N and n * K_pack can exceed INT_MAX for
        // large weights, and signed int overflow is undefined behavior.
        const size_t k_begin = static_cast<size_t>(k) * k_tile;
        const int    kk_max  = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
        for (int n = 0; n < N; n += 2) {
            const float16_t* b_fp16_ptr1 = b_fp16 + k_begin * N + n + 0;
            const float16_t* b_fp16_ptr2 = b_fp16 + k_begin * N + n + 1;
            hie::bfloat16*   b_bf16_ptr  = b_bf16 + static_cast<size_t>(n) * K_pack + k_begin * 2;  // [n, k*k_tile*2]
            for (int kk = k * k_tile; kk < kk_max; kk += 4) {
                for (int i = 0; i < 4 && (kk + i < kk_max); i++) {
                    b_bf16_ptr[i] = b_fp16_ptr1[i * N];
                    if (n != (N - 1)) {
                        b_bf16_ptr[i + 4] = b_fp16_ptr2[i * N];
                    }
                }
                b_bf16_ptr += 8;
                b_fp16_ptr1 += 4 * N;
                b_fp16_ptr2 += 4 * N;
            }
        }
    });


#ifdef PACK_DEBUG
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < K; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            std::cout << b_fp16[j * N + i] << " ";
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");

    auto N_aligned = N / 2 + (N % 2);
    for (int i = 0; i < N_aligned; i++) {
        for (int j = 0; j < K_pack * 2; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            std::cout << std::setiosflags(std::ios::fixed) << std::setprecision(6) << b_bf16[i * K_pack * 2 + j] << " ";
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");
#endif
}


// Packs M rows of a row-major fp16 activation matrix A (row stride lda, K valid columns) into
// the BF16 layout used by the ARM BF16 GEMM.
//
// Layout:
//  - Rows are taken in pairs (m, m + 1). Each pair occupies 2 * K_pack bf16 values starting
//    at a_bf16 + m * K_pack.
//  - Every 4 consecutive k are stored as 4 bf16 values from row m, followed by the same 4 k
//    from row m + 1.
//  - When M is odd, the missing partner row is written as zeros.
//  - Positions at or past K are not written.
//
// K is split into tiles of 1024 columns that run via parallel_for. N is unused.
// Requires SVE with BF16 (bfcvt).
//
// NOTE(review): the loop advances its pointers by a fixed 16 bytes of input / 32 bytes of
// output per iteration, while the SVE loads and predicates scale with the vector length.
// This appears to assume a 128-bit SVE vector length -- confirm before running on wider-VL cores.
void GemmKernel::pack_input_fp16tobf16_impl_parallel_simd(
    int M, int N, int K, int lda, int K_pack, float16_t* a_fp16, hie::bfloat16* a_bf16) {
#define LABEL_FOR_LOOP_M "0"
#define LABEL_FOR_LOOP_K "1"
#define LABEL_m_EQ_M_1 "2"
    // The asm loop over m is do-while shaped; with no rows it would still touch rows 0 and 1.
    if (M <= 0) {
        return;
    }

    int k_tile   = 1024;  // empirical var: 1024, 5120
    int k_thread = std::ceil(K * 1.0 / k_tile);

    // fp16 [ a[i,  j+0], a[i,  j+1], a[i,  j+2], a[i,  j+3] ]
    // fp16 [ a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3] ]
    // bf16 [ a[i,  j+0], a[i,  j+1], a[i,  j+2], a[i,  j+3],
    //        a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3]]

    parallel_for(k_thread, [&](int k) {
        // These pointers are advanced by the asm, so they are bound as in/out ("+r") operands.
        float16_t*     a_fp16_ptr1   = a_fp16 + 0 * lda + k * k_tile;
        float16_t*     a_fp16_ptr2   = a_fp16 + 1 * lda + k * k_tile;
        hie::bfloat16* a_bf16_ptr    = a_bf16 + k * k_tile * 2;
        // Every scalar below is consumed as a 64-bit x register inside the asm, so it is int64_t.
        // Binding a 32-bit int to "r" leaves the upper half of that register unspecified.
        const int64_t  a_fp16_offset = static_cast<int64_t>(2 * lda * sizeof(float16_t));
        const int64_t  a_bf16_offset = static_cast<int64_t>(2 * K_pack * sizeof(hie::bfloat16)); // if K_pack % 16 == 8, for the remain 8 zero elements, use next line to cover it
        const int64_t  kk            = static_cast<int64_t>(k) * k_tile;
        const int64_t  kk_max        = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
        const int64_t  m_total       = M;

        // clang-format off
        asm volatile(
            "ptrue   p0.b                                    \n"
            "sub     x1,    %[M], #1                         \n"  // M - 1
            "mov     x2,    #0                               \n"  // m

            "" LABEL_FOR_LOOP_M
            ":\n"
            "mov     x3,    %[a_fp16_ptr1]                   \n"
            "mov     x4,    %[a_fp16_ptr2]                   \n"
            "mov     x5,    %[a_bf16_ptr]                    \n"

            "prfw    pldl1strm, p0, [x3,    #0, MUL VL]      \n"  // prefetch
            "prfw    pldl1strm, p0, [x4,    #0, MUL VL]      \n"

            "mov     x0,    %[kk]                            \n"
            "whilelt p1.h,  x0,   %[kk_max]                  \n"  // compare kk
                                                                  // and kk_max

            "" LABEL_FOR_LOOP_K
            ":\n"
            "ld1h   z4.h, p1/z, [x3,    #0, MUL VL]          \n" // load 8 fp16
            "dup    z6.h, #0                                 \n"
            "zip1   z0.h, z4.h, z6.h                         \n"  // zip 4(or less) fp16 values with 0
            "zip2   z1.h, z4.h, z6.h                         \n"  // zip 4(or less) fp16 values with 0
            "fcvt   z0.s, p0/m, z0.h                         \n"  // fp16 -> fp32
            "dup    z2.h, #0                                 \n"
            "fcvt   z1.s, p0/m, z1.h                         \n"  // fp16 -> fp32
            "dup    z3.h, #0                                 \n"
            "cmp    x2, x1                                   \n"  // compare m,
                                                                  // M - 1
            "b.none  " LABEL_m_EQ_M_1
            "f                     \n"
            "ld1h   z5.h, p1/z, [x4,    #0, MUL VL]          \n"  // load, when
                                                                  // m != M - 1
            "zip1   z2.h, z5.h, z6.h                         \n"  // zip 4(or less) fp16 values with 0
            "zip2   z3.h, z5.h, z6.h                         \n"  // zip 4(or less) fp16 values with 0
            "fcvt   z2.s, p0/m, z2.h                         \n"  // fp16 -> fp32
            "fcvt   z3.s, p0/m, z3.h                         \n"  // fp16 -> fp32

            "" LABEL_m_EQ_M_1
            ":\n"
            "add     x3, x3, #16                             \n"  // a_fp16_ptr1 += 8
            "add     x4, x4, #16                             \n"  // a_fp16_ptr2 += 8
            // "add     x3, x3, #8                              \n"  // a_fp16_ptr1 += 4
            // "add     x4, x4, #8                              \n"  // a_fp16_ptr2 += 4

            "prfw    pldl1strm, p0, [x3,    #0, MUL VL]      \n"
            "prfw    pldl1strm, p0, [x4,    #0, MUL VL]      \n"

            "bfcvt   z0.h, p0/m, z0.s                        \n"  // fp32 ->
                                                                  // bf16
            "bfcvt   z1.h, p0/m, z1.s                        \n"
            "bfcvt   z2.h, p0/m, z2.s                        \n"
            "bfcvt   z3.h, p0/m, z3.s                        \n"

            "uzp1    z4.h, z0.h, z2.h                        \n"  // combine
                                                                  // bf16
            "uzp1    z5.h, z1.h, z3.h                        \n"  // combine bf16
            "zip1    p3.d, p1.d, p1.d                        \n"  // cp 4 least significant half to 4 most significant half
            ""
            "st1h    z4.h, p3,   [x5, #0, MUL VL]            \n"  // store bf16 data

            "zip2    p3.d, p1.d, p1.d                        \n"  // cp 4 most significant half to 4 least significant half
            "st1h    z5.h, p3,   [x5, #1, MUL VL]            \n"  // store bf16
            "add     x5, x5, #32                             \n"  // a_bf16_ptr += 16
            // "add     x5, x5, #16                             \n"  // a_bf16_ptr += 8

            //   "prfw    pstl1keep, p0, [x5,    #0, MUL VL]      \n"

            "add     x0,    x0,   #8                         \n"  // kk += 8
            // "add     x0,    x0,   #4                         \n"  // kk += 4
            "whilelt p1.h,  x0,   %[kk_max]                  \n"  // compare kk
                                                                  // and kk_max
            "b.tstop " LABEL_FOR_LOOP_K
            "b                   \n"  // if k < K_MAX, go to label

            "add     %[a_fp16_ptr1], %[a_fp16_ptr1], %[a_fp16_offset] \n"
            "add     %[a_fp16_ptr2], %[a_fp16_ptr2], %[a_fp16_offset] \n"
            "add     %[a_bf16_ptr],  %[a_bf16_ptr],  %[a_bf16_offset] \n"
            "add     x2,    x2,   #2                         \n"  // m += 2
            "cmp     x2, %[M]                                \n"  // compare m,
                                                                  // M
            "b.tstop " LABEL_FOR_LOOP_M
            "b                   \n"  // if m < M, go to label

            : [a_fp16_ptr1] "+r"(a_fp16_ptr1), [a_fp16_ptr2] "+r"(a_fp16_ptr2),
              [a_bf16_ptr] "+r"(a_bf16_ptr)
            : [kk] "r"(kk), [kk_max] "r"(kk_max),
              [M] "r"(m_total), [a_fp16_offset] "r"(a_fp16_offset),
              [a_bf16_offset] "r"(a_bf16_offset)
            : "x0", "x1", "x2", "x3", "x4", "x5",
              "p0", "p1", "p2", "p3",
              "z0", "z1", "z2", "z3", "z4", "z5", "z6",
              "cc", "memory");
        // clang-format on
    });

#ifdef PACK_DEBUG
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < K; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            printf("%f ", a_fp16[i * lda + j]);
            // std::cout << a_fp16[i * lda + j] << " ";
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");


    // int k_pack_compute = std::ceil(K / 16.0) * 16;
    auto M_aligned = M + (M % 2);
    for (int i = 0; i < M_aligned / 2; i++) {
        for (int j = 0; j < K_pack * 2; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            std::cout << a_bf16[i * K_pack * 2 + j] << " ";
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");
#endif
}

// Packs M rows of a row-major fp32 activation matrix A (row stride lda, K valid columns) into
// the BF16 layout used by the ARM BF16 GEMM.
//
// Layout:
//  - Rows are taken in pairs (m, m + 1). Each pair occupies 2 * K_pack bf16 values starting
//    at a_bf16 + m * K_pack.
//  - Every 4 consecutive k are stored as 4 bf16 values from row m, followed by the same 4 k
//    from row m + 1.
//  - When M is odd, the missing partner row is written as zeros.
//  - Positions at or past K are not written.
//
// K is split into tiles of 1024 columns that run via parallel_for. N is unused.
// Requires SVE with BF16 (bfcvt).
//
// NOTE(review): the loop advances its pointers by a fixed 16 bytes per iteration, while the
// SVE loads and predicates scale with the vector length. This appears to assume a 128-bit SVE
// vector length -- confirm before running on wider-VL cores.
void GemmKernel::pack_input_impl_parallel_simd(
    int M, int N, int K, int lda, int K_pack, float* a_fp32, hie::bfloat16* a_bf16) {
#define LABEL_FOR_LOOP_M "0"
#define LABEL_FOR_LOOP_K "1"
#define LABEL_m_EQ_M_1 "2"
    // The asm loop over m is do-while shaped; with no rows it would still touch rows 0 and 1.
    if (M <= 0) {
        return;
    }

    int k_tile   = 1024;  // empirical var: 1024, 5120
    int k_thread = std::ceil(K * 1.0 / k_tile);

    // bfcvt leaves each bf16 result in the even 16-bit lane and uzp1 gathers the even lanes of
    // (row m, row m + 1), so row m comes first:
    // fp32 [ a[i,  j+0], a[i,  j+1], a[i,  j+2], a[i,  j+3] ]
    // fp32 [ a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3] ]
    // bf16 [ a[i,  j+0], a[i,  j+1], a[i,  j+2], a[i,  j+3],
    //        a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3]]

    parallel_for(k_thread, [&](int k) {
        // These pointers are advanced by the asm, so they are bound as in/out ("+r") operands.
        float*         a_fp32_ptr1   = a_fp32 + 0 * lda + k * k_tile;
        float*         a_fp32_ptr2   = a_fp32 + 1 * lda + k * k_tile;
        hie::bfloat16* a_bf16_ptr    = a_bf16 + k * k_tile * 2;
        // Every scalar below is consumed as a 64-bit x register inside the asm, so it is int64_t.
        // Binding a 32-bit int to "r" leaves the upper half of that register unspecified.
        const int64_t  a_fp32_offset = static_cast<int64_t>(2 * lda * sizeof(float));
        const int64_t  a_bf16_offset = static_cast<int64_t>(2 * K_pack * sizeof(hie::bfloat16));
        const int64_t  kk            = static_cast<int64_t>(k) * k_tile;
        const int64_t  kk_max        = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
        const int64_t  m_total       = M;

        // clang-format off
        asm volatile(
            "ptrue   p0.b                                    \n"
            "sub     x1,    %[M], #1                         \n"  // M - 1
            "mov     x2,    #0                               \n"  // m

            "" LABEL_FOR_LOOP_M
            ":\n"
            "mov     x3,    %[a_fp32_ptr1]                   \n"
            "mov     x4,    %[a_fp32_ptr2]                   \n"
            "mov     x5,    %[a_bf16_ptr]                    \n"

            "prfw    pldl1strm, p0, [x3,    #0, MUL VL]      \n"  // prefetch
            "prfw    pldl1strm, p0, [x4,    #0, MUL VL]      \n"

            "mov     x0,    %[kk]                            \n"
            "whilelt p1.s,  x0,   %[kk_max]                  \n"  // compare kk
                                                                  // and kk_max

            "" LABEL_FOR_LOOP_K
            ":\n"
            "ld1w   z0.s, p1/z, [x3,    #0, MUL VL]          \n"
            "dup    z1.h, #0                                 \n"
            "cmp    x2, x1                                   \n"  // compare m,
                                                                  // M - 1
            "b.none  " LABEL_m_EQ_M_1
            "f                     \n"
            "ld1w   z1.s, p1/z, [x4,    #0, MUL VL]          \n"  // load, when
                                                                  // m != M - 1

            "" LABEL_m_EQ_M_1
            ":\n"
            "add     x3, x3, #16                             \n"
            "add     x4, x4, #16                             \n"

            "prfw    pldl1strm, p0, [x3,    #0, MUL VL]      \n"
            "prfw    pldl1strm, p0, [x4,    #0, MUL VL]      \n"

            "bfcvt   z0.h, p0/m, z0.s                        \n"  // fp32 ->
                                                                  // bf16
            "bfcvt   z1.h, p0/m, z1.s                        \n"
            "uzp1    z2.h, z0.h, z1.h                        \n"  // combine
                                                                  // bf16

            "uzp1    p3.h, p1.h, p1.h                        \n"
            "st1h    z2.h, p3,   [x5, #0, MUL VL]            \n"  // store bf16
                                                                  // data
            "add     x5, x5, #16                             \n"

            //   "prfw    pstl1keep, p0, [x5,    #0, MUL VL]      \n"

            "add     x0,    x0,   #4                         \n"  // kk += 4
            "whilelt p1.s,  x0,   %[kk_max]                  \n"  // compare kk
                                                                  // and kk_max
            "b.tstop " LABEL_FOR_LOOP_K
            "b                   \n"  // if k < K_MAX, go to label

            "add     %[a_fp32_ptr1], %[a_fp32_ptr1], %[a_fp32_offset] \n"
            "add     %[a_fp32_ptr2], %[a_fp32_ptr2], %[a_fp32_offset] \n"
            "add     %[a_bf16_ptr],  %[a_bf16_ptr],  %[a_bf16_offset] \n"
            "add     x2,    x2,   #2                         \n"  // m += 2
            "cmp     x2, %[M]                                \n"  // compare m,
                                                                  // M
            "b.tstop " LABEL_FOR_LOOP_M
            "b                   \n"  // if m < M, go to label

            : [a_fp32_ptr1] "+r"(a_fp32_ptr1), [a_fp32_ptr2] "+r"(a_fp32_ptr2),
              [a_bf16_ptr] "+r"(a_bf16_ptr)
            : [kk] "r"(kk), [kk_max] "r"(kk_max),
              [M] "r"(m_total), [a_fp32_offset] "r"(a_fp32_offset),
              [a_bf16_offset] "r"(a_bf16_offset)
            : "x0", "x1", "x2", "x3", "x4", "x5", "p0", "p1", "p2", "p3", "z0",
              "z1", "z2", "cc", "memory");
        // clang-format on
    });

#ifdef PACK_DEBUG
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < K; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            printf("%f ", a_fp32[i * lda + j]);
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");


    auto M_aligned = M + (M % 2);
    for (int i = 0; i < M_aligned / 2; i++) {
        for (int j = 0; j < K_pack * 2; j++) {
            if (j % 8 == 0) {
                printf("\n");
            }
            std::cout << a_bf16[i * K_pack * 2 + j] << " ";
        }
        printf("\n");
        printf("\n");
    }
    printf("\n");
#endif
}

}  // namespace rtp_llm