void MNNPackedSparseQuantMatMulEpx4()

in source/backend/cpu/compute/Int8FunctionsOpt.cpp [581:1404]


void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {

    size_t eSize = sparseQuantParam[0];
    size_t eP = sparseQuantParam[1];
    size_t aStride = sparseQuantParam[2];
    size_t l = sparseQuantParam[3];
    size_t h = sparseQuantParam[4];
    size_t cStride = sparseQuantParam[5];

    const int32_t* bias = post->bias;
    const float* scales = post->scale;
    const int32_t maxValue = post->maxValue;
    const int32_t minValue = post->minValue;

    const int sparseBlockOC = 4;
    const int8_t * a = A;
    size_t ie = 0;
    for (ie = 0; ie < eSize && eP <= eSize; ie += eP) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        a += diff;
        const int8_t * w = B;
        int8_t * blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;

            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            int32_t acc2[4];
            int32_t acc3[4];
            int32_t acc4[4];
            int32_t acc5[4];
            int32_t acc6[4];
            int32_t acc7[4];
            int32_t acc8[4];
            int32_t acc9[4];
            int32_t acc10[4];
            int32_t acc11[4];
            int32_t acc12[4];
            int32_t acc13[4];
            int32_t acc14[4];
            int32_t acc15[4];

            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));
            memcpy(acc2, initValue, 4 * sizeof(int32_t));
            memcpy(acc3, initValue, 4 * sizeof(int32_t));
            memcpy(acc4, initValue, 4 * sizeof(int32_t));
            memcpy(acc5, initValue, 4 * sizeof(int32_t));
            memcpy(acc6, initValue, 4 * sizeof(int32_t));
            memcpy(acc7, initValue, 4 * sizeof(int32_t));
            memcpy(acc8, initValue, 4 * sizeof(int32_t));
            memcpy(acc9, initValue, 4 * sizeof(int32_t));
            memcpy(acc10, initValue, 4 * sizeof(int32_t));
            memcpy(acc11, initValue, 4 * sizeof(int32_t));
            memcpy(acc12, initValue, 4 * sizeof(int32_t));
            memcpy(acc13, initValue, 4 * sizeof(int32_t));
            memcpy(acc14, initValue, 4 * sizeof(int32_t));
            memcpy(acc15, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t a8 = a[8];
                const int8_t a9 = a[9];
                const int8_t a10 = a[10];
                const int8_t a11 = a[11];
                const int8_t a12 = a[12];
                const int8_t a13 = a[13];
                const int8_t a14 = a[14];
                const int8_t a15 = a[15];

                const int8_t wv[4] = {*w++, *w++, *w++, *w++};

                // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%f, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {16});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += (int32_t)a0 * (int32_t)wv[lane];
                    acc1[lane] += (int32_t)a1 * (int32_t)wv[lane];
                    acc2[lane] += (int32_t)a2 * (int32_t)wv[lane];
                    acc3[lane] += (int32_t)a3 * (int32_t)wv[lane];
                    acc4[lane] += (int32_t)a4 * (int32_t)wv[lane];
                    acc5[lane] += (int32_t)a5 * (int32_t)wv[lane];
                    acc6[lane] += (int32_t)a6 * (int32_t)wv[lane];
                    acc7[lane] += (int32_t)a7 * (int32_t)wv[lane];
                    acc8[lane] += (int32_t)a8 * (int32_t)wv[lane];
                    acc9[lane] += (int32_t)a9 * (int32_t)wv[lane];
                    acc10[lane] += (int32_t)a10 * (int32_t)wv[lane];
                    acc11[lane] += (int32_t)a11 * (int32_t)wv[lane];
                    acc12[lane] += (int32_t)a12 * (int32_t)wv[lane];
                    acc13[lane] += (int32_t)a13 * (int32_t)wv[lane];
                    acc14[lane] += (int32_t)a14 * (int32_t)wv[lane];
                    acc15[lane] += (int32_t)a15 * (int32_t)wv[lane];
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            int8_t result2[4];
            int8_t result3[4];
            int8_t result4[4];
            int8_t result5[4];
            int8_t result6[4];
            int8_t result7[4];
            int8_t result8[4];
            int8_t result9[4];
            int8_t result10[4];
            int8_t result11[4];
            int8_t result12[4];
            int8_t result13[4];
            int8_t result14[4];
            int8_t result15[4];

            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                    result2[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue))));
                    result3[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue))));
                    result4[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc4[lane])), float(minValue))));
                    result5[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc5[lane])), float(minValue))));
                    result6[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc6[lane])), float(minValue))));
                    result7[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc7[lane])), float(minValue))));
                    result8[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc8[lane])), float(minValue))));
                    result9[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc9[lane])), float(minValue))));
                    result10[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc10[lane])), float(minValue))));
                    result11[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc11[lane])), float(minValue))));
                    result12[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc12[lane])), float(minValue))));
                    result13[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc13[lane])), float(minValue))));
                    result14[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc14[lane])), float(minValue))));
                    result15[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc15[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                    result2[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue)));
                    result3[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue)));
                    result4[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc4[lane]), minValue)));
                    result5[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc5[lane]), minValue)));
                    result6[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc6[lane]), minValue)));
                    result7[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc7[lane]), minValue)));
                    result8[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc8[lane]), minValue)));
                    result9[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc9[lane]), minValue)));
                    result10[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc10[lane]), minValue)));
                    result11[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc11[lane]), minValue)));
                    result12[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc12[lane]), minValue)));
                    result13[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc13[lane]), minValue)));
                    result14[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc14[lane]), minValue)));
                    result15[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc15[lane]), minValue)));
                }
            }

            memcpy(c         , result0, 4 * sizeof(int8_t));  // store continuous c
            memcpy(c + 4     , result1, 4 * sizeof(int8_t));
            memcpy(c + 4 * 2 , result2, 4 * sizeof(int8_t));
            memcpy(c + 4 * 3 , result3, 4 * sizeof(int8_t));
            memcpy(c + 4 * 4 , result4, 4 * sizeof(int8_t));
            memcpy(c + 4 * 5 , result5, 4 * sizeof(int8_t));
            memcpy(c + 4 * 6 , result6, 4 * sizeof(int8_t));
            memcpy(c + 4 * 7 , result7, 4 * sizeof(int8_t));
            memcpy(c + 4 * 8 , result8, 4 * sizeof(int8_t));
            memcpy(c + 4 * 9 , result9, 4 * sizeof(int8_t));
            memcpy(c + 4 * 10, result10, 4 * sizeof(int8_t));
            memcpy(c + 4 * 11, result11, 4 * sizeof(int8_t));
            memcpy(c + 4 * 12, result12, 4 * sizeof(int8_t));
            memcpy(c + 4 * 13, result13, 4 * sizeof(int8_t));
            memcpy(c + 4 * 14, result14, 4 * sizeof(int8_t));
            memcpy(c + 4 * 15, result15, 4 * sizeof(int8_t));
        }

        blockC += (h >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;
            int32_t acc4 = initValue;
            int32_t acc5 = initValue;
            int32_t acc6 = initValue;
            int32_t acc7 = initValue;
            int32_t acc8 = initValue;
            int32_t acc9 = initValue;
            int32_t acc10 = initValue;
            int32_t acc11 = initValue;
            int32_t acc12 = initValue;
            int32_t acc13 = initValue;
            int32_t acc14 = initValue;
            int32_t acc15 = initValue;
            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t a8 = a[8];
                const int8_t a9 = a[9];
                const int8_t a10 = a[10];
                const int8_t a11 = a[11];
                const int8_t a12 = a[12];
                const int8_t a13 = a[13];
                const int8_t a14 = a[14];
                const int8_t a15 = a[15];

                const int8_t oneW = *w++;

                // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%f, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {16});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += (int32_t)a0 * (int32_t)oneW;
                acc1 += (int32_t)a1 * (int32_t)oneW;
                acc2 += (int32_t)a2 * (int32_t)oneW;
                acc3 += (int32_t)a3 * (int32_t)oneW;
                acc4 += (int32_t)a4 * (int32_t)oneW;
                acc5 += (int32_t)a5 * (int32_t)oneW;
                acc6 += (int32_t)a6 * (int32_t)oneW;
                acc7 += (int32_t)a7 * (int32_t)oneW;
                acc8 += (int32_t)a8 * (int32_t)oneW;
                acc9 += (int32_t)a9 * (int32_t)oneW;
                acc10 += (int32_t)a10 * (int32_t)oneW;
                acc11 += (int32_t)a11 * (int32_t)oneW;
                acc12 += (int32_t)a12 * (int32_t)oneW;
                acc13 += (int32_t)a13 * (int32_t)oneW;
                acc14 += (int32_t)a14 * (int32_t)oneW;
                acc15 += (int32_t)a15 * (int32_t)oneW;
            }

            int8_t result0; // in assemmbly code, consider reuse acc0[0-8] bit
            int8_t result1;
            int8_t result2;
            int8_t result3;
            int8_t result4;
            int8_t result5;
            int8_t result6;
            int8_t result7;
            int8_t result8;
            int8_t result9;
            int8_t result10;
            int8_t result11;
            int8_t result12;
            int8_t result13;
            int8_t result14;
            int8_t result15;

            if (scales) {
                result0  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
                result4  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue))));
                result5  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue))));
                result6  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue))));
                result7  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue))));
                result8  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc8)), float(minValue))));
                result9  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc9)), float(minValue))));
                result10 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc10)), float(minValue))));
                result11 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc11)), float(minValue))));
                result12 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc12)), float(minValue))));
                result13 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc13)), float(minValue))));
                result14 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc14)), float(minValue))));
                result15 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc15)), float(minValue))));
            } else {
                result0  = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1  = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2  = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3  = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
                result4  = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue));
                result5  = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue));
                result6  = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue));
                result7  = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue));
                result8  = static_cast<int8_t>(std::max(std::min(maxValue, acc8), minValue));
                result9  = static_cast<int8_t>(std::max(std::min(maxValue, acc9), minValue));
                result10 = static_cast<int8_t>(std::max(std::min(maxValue, acc10), minValue));
                result11 = static_cast<int8_t>(std::max(std::min(maxValue, acc11), minValue));
                result12 = static_cast<int8_t>(std::max(std::min(maxValue, acc12), minValue));
                result13 = static_cast<int8_t>(std::max(std::min(maxValue, acc13), minValue));
                result14 = static_cast<int8_t>(std::max(std::min(maxValue, acc14), minValue));
                result15 = static_cast<int8_t>(std::max(std::min(maxValue, acc15), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
            c[4 * 4] = result4;
            c[4 * 5] = result5;
            c[4 * 6] = result6;
            c[4 * 7] = result7;
            c[4 * 8] = result8;
            c[4 * 9] = result9;
            c[4 * 10] = result10;
            c[4 * 11] = result11;
            c[4 * 12] = result12;
            c[4 * 13] = result13;
            c[4 * 14] = result14;
            c[4 * 15] = result15;
        }
        a += aStride;
    }
    if (eSize & 0x08) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            int32_t acc2[4];
            int32_t acc3[4];
            int32_t acc4[4];
            int32_t acc5[4];
            int32_t acc6[4];
            int32_t acc7[4];

            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));
            memcpy(acc2, initValue, 4 * sizeof(int32_t));
            memcpy(acc3, initValue, 4 * sizeof(int32_t));
            memcpy(acc4, initValue, 4 * sizeof(int32_t));
            memcpy(acc5, initValue, 4 * sizeof(int32_t));
            memcpy(acc6, initValue, 4 * sizeof(int32_t));
            memcpy(acc7, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0-7]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {8});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                    acc1[lane] += int32_t(a1) * int32_t(wv[lane]);
                    acc2[lane] += int32_t(a2) * int32_t(wv[lane]);
                    acc3[lane] += int32_t(a3) * int32_t(wv[lane]);
                    acc4[lane] += int32_t(a4) * int32_t(wv[lane]);
                    acc5[lane] += int32_t(a5) * int32_t(wv[lane]);
                    acc6[lane] += int32_t(a6) * int32_t(wv[lane]);
                    acc7[lane] += int32_t(a7) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            int8_t result2[4];
            int8_t result3[4];
            int8_t result4[4];
            int8_t result5[4];
            int8_t result6[4];
            int8_t result7[4];

            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                    result2[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue))));
                    result3[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue))));
                    result4[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc4[lane])), float(minValue))));
                    result5[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc5[lane])), float(minValue))));
                    result6[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc6[lane])), float(minValue))));
                    result7[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc7[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                    result2[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue)));
                    result3[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue)));
                    result4[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc4[lane]), minValue)));
                    result5[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc5[lane]), minValue)));
                    result6[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc6[lane]), minValue)));
                    result7[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc7[lane]), minValue)));
                }
            }

            memcpy(c         , result0, 4 * sizeof(int8_t));  // store continuous c
            memcpy(c + 4     , result1, 4 * sizeof(int8_t));
            memcpy(c + 4 * 2 , result2, 4 * sizeof(int8_t));
            memcpy(c + 4 * 3 , result3, 4 * sizeof(int8_t));
            memcpy(c + 4 * 4 , result4, 4 * sizeof(int8_t));
            memcpy(c + 4 * 5 , result5, 4 * sizeof(int8_t));
            memcpy(c + 4 * 6 , result6, 4 * sizeof(int8_t));
            memcpy(c + 4 * 7 , result7, 4 * sizeof(int8_t));

        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;
            int32_t acc4 = initValue;
            int32_t acc5 = initValue;
            int32_t acc6 = initValue;
            int32_t acc7 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t a4 = a[4];
                const int8_t a5 = a[5];
                const int8_t a6 = a[6];
                const int8_t a7 = a[7];
                const int8_t oneW = *w++;
                // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-7]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {8});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
                acc2 += int32_t(a2) * int32_t(oneW);
                acc3 += int32_t(a3) * int32_t(oneW);
                acc4 += int32_t(a4) * int32_t(oneW);
                acc5 += int32_t(a5) * int32_t(oneW);
                acc6 += int32_t(a6) * int32_t(oneW);
                acc7 += int32_t(a7) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            int8_t result2;
            int8_t result3;
            int8_t result4;
            int8_t result5;
            int8_t result6;
            int8_t result7;
            if (scales) {
                result0  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
                result4  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue))));
                result5  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue))));
                result6  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue))));
                result7  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue))));

            } else {
                result0  = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1  = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2  = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3  = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
                result4  = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue));
                result5  = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue));
                result6  = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue));
                result7  = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
            c[4 * 4] = result4;
            c[4 * 5] = result5;
            c[4 * 6] = result6;
            c[4 * 7] = result7;
        }
        ie += 8;
        a += 8;
    }
    if (eSize & 0x04) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            int32_t acc2[4];
            int32_t acc3[4];

            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));
            memcpy(acc2, initValue, 4 * sizeof(int32_t));
            memcpy(acc3, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:, a value[0-3]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {4});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                    acc1[lane] += int32_t(a1) * int32_t(wv[lane]);
                    acc2[lane] += int32_t(a2) * int32_t(wv[lane]);
                    acc3[lane] += int32_t(a3) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            int8_t result2[4];
            int8_t result3[4];

            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                    result2[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue))));
                    result3[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                    result2[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue)));
                    result3[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue)));
                }
            }

            memcpy(c         , result0, 4 * sizeof(int8_t));  // store continuous c
            memcpy(c + 4     , result1, 4 * sizeof(int8_t));
            memcpy(c + 4 * 2 , result2, 4 * sizeof(int8_t));
            memcpy(c + 4 * 3 , result3, 4 * sizeof(int8_t));

        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;
            int32_t acc2 = initValue;
            int32_t acc3 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t a2 = a[2];
                const int8_t a3 = a[3];
                const int8_t oneW = *w++;
                // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-3]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {4});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
                acc2 += int32_t(a2) * int32_t(oneW);
                acc3 += int32_t(a3) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            int8_t result2;
            int8_t result3;
            if (scales) {
                result0  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
                result2  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue))));
                result3  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue))));
            } else {
                result0  = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1  = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
                result2  = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue));
                result3  = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
            c[4 * 2] = result2;
            c[4 * 3] = result3;
        }
        ie += 4;
        a += 4;
    }
    if (eSize & 0x02) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // a = blockA + diff;
        a += diff;
        const int8_t* w = B;
        int8_t* blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            int32_t acc1[4];
            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            memcpy(acc1, initValue, 4 * sizeof(int32_t));

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:, a value[0-1]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {2});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                    acc1[lane] += int32_t(a1) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            int8_t result1[4];
            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                    result1[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue)));
                }
            }

            memcpy(c         , result0, 4 * sizeof(int8_t));  // store continuous c
            memcpy(c + 4     , result1, 4 * sizeof(int8_t));
        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;
            int32_t acc1 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t a1 = a[1];
                const int8_t oneW = *w++;
                // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-1]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {2});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
                acc1 += int32_t(a1) * int32_t(oneW);
            }

            int8_t result0;
            int8_t result1;
            if (scales) {
                result0  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
                result1  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue))));
            } else {
                result0  = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
                result1  = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue));
            }

            // how to store faster: st4 / transpose /
            c[0] = result0;
            c[4] = result1;
        }
        ie += 2;
        a += 2;
    }
    if (eSize & 0x01) {
        const int* dataOffset = dataOffsetMap;
        const int diff = *dataOffset++;
        // const float* a = blockA + diff;
        a += diff;
        const int8_t * w = B;
        int8_t * blockC = C + (ie << 2);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < (h & (~0x03)); ih += sparseBlockOC) {
            auto ihPack = ih >> 2;
            auto c = blockC + ihPack * cStride;
            int32_t initValue[4] = {0, 0, 0, 0};
            if (nullptr != bias) {
                memcpy(initValue, bias + ih, 4 * sizeof(int32_t));
            }
            int32_t acc0[4];
            memcpy(acc0, initValue, 4 * sizeof(int32_t));
            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {

                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t wv[4] = {*w++, *w++, *w++, *w++};
                // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:, a value[0-1]:\n", ie, a - A, w - B - 1, c - C);
                // formatMatrix(wv, {4});
                // formatMatrix(a, {16});
                // MNN_PRINT("\n");
                a = a + diff;
                for (int lane = 0; lane < 4; lane++) {
                    acc0[lane] += int32_t(a0) * int32_t(wv[lane]);
                }
            }

            int8_t result0[4];
            if (scales) {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue))));
                }
            } else {
                for (int lane = 0; lane < 4; lane++) {
                    result0[lane]  = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue)));
                }
            }
            memcpy(c, result0, 4 * sizeof(int8_t));  // store continuous c
        }
        blockC += (ih >> 2) * cStride;
        for (; ih < h; ih++) {
            auto ihSubIndex = ih & 0x03;
            auto c = blockC + ihSubIndex;
            const int32_t initValue = nullptr != bias ? bias[ih] : 0;
            int32_t acc0 = initValue;

            const int lElement = *nnz++;
            for (auto il = 0; il < lElement; il++) {
                const int diff = *dataOffset++;
                const int8_t a0 = a[0];
                const int8_t oneW = *w++;

                // MNN_PRINT("1-loop: ie:%zu, a offset:%ld, c offset:%ld, w offset:%ld, w value:%d, a value[0]:\n", ie, a - A, w - B - 1, c - C, oneW);
                // formatMatrix(a, {1});
                // MNN_PRINT("\n");
                a = a + diff;
                acc0 += int32_t(a0) * int32_t(oneW);
            }
            int8_t result0;
            if (scales) {
                result0  = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue))));
            } else {
                result0  = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue));
            }
            // how to store faster: st4 / transpose /
            c[0] = result0;
        }
        ie += 1;
        // a += 1;
    }

}