void _AVX512_MNNPackedSparseMatMulEpx4()

in source/backend/cpu/x86_x64/avx512/SparseKernelFunctionEpx4.cpp [19:811]


/**
 * Sparse packed matrix multiply, AVX-512 path, for weights stored in groups of
 * 4 output channels ("Epx4").
 *
 * For every tile of `e` columns the kernel walks the output channels `h`:
 *   - first in groups of `sparseBlockOC` (4) channels, reading 4 weights per
 *     non-zero entry, then
 *   - one channel at a time for the remaining `h % 4` channels, reading 1
 *     weight per non-zero entry.
 * Each accumulator starts from `bias` (or zero), is clamped to
 * [postParameters[2], postParameters[3]] and is written to C, whose channel
 * dimension is packed in units of 16 (`packCUnit`).
 *
 * The `e` dimension is processed in full tiles of 48 columns (three 512-bit
 * vectors), followed by tails of 32, 16, 8, 4, 2 and 1 columns selected by the
 * bits of `eSize % eP`.
 * NOTE(review): the full-tile loop always loads 48 floats, so it assumes
 * eP == 48 — confirm against the caller.
 *
 * @param C              Output. Element (e, h) is written at
 *                       C[(h / 16) * cStride + e * 16 + (h % 16)].
 * @param A              Input. The read cursor starts at A and only moves by
 *                       the deltas in `dataOffsetMap`, by `eP * l` after each
 *                       full tile and by the tail width after each tail.
 * @param B              Non-zero weights, consumed sequentially: 4 per entry in
 *                       the 4-channel groups, 1 per entry in the remainder.
 *                       The walk restarts from B for every e-tile.
 * @param eSize          Number of columns in the e dimension.
 * @param parameter      [0] eP * sizeof(float), [1] l, [2] h,
 *                       [3] C stride between 16-channel packs, in bytes.
 *                       Other entries are not read here.
 * @param postParameters [2] lower clamp bound, [3] upper clamp bound.
 * @param bias           Per-channel bias, or nullptr for zero initialisation.
 * @param NNZMap         Non-zero entry count: one value per 4-channel group,
 *                       then one per remaining channel.
 * @param dataOffsetMap  Deltas (in floats) applied to the A cursor: one leading
 *                       delta, then one per non-zero entry.
 *                       NOTE(review): the deltas of one full pass presumably
 *                       sum to zero so the cursor returns to the tile base —
 *                       confirm against the code that builds the map.
 */
void _AVX512_MNNPackedSparseMatMulEpx4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter,
                                    const float* postParameters, const float* bias, unsigned int* NNZMap,
                                    int* dataOffsetMap) {

    const size_t eP = parameter[0] / sizeof(float);
    const size_t l = parameter[1];
    const size_t h = parameter[2];
    const size_t cStride = parameter[3] / sizeof(float);
    const size_t aStride = eP * l; // floats occupied by one e-tile of A

    constexpr size_t packCUnit = 16;
    constexpr size_t packCUnitLog = 4;
    constexpr int sparseBlockOC = 4;

    // Loop-invariant bounds: end of the 4-channel groups, and the offset of the
    // 16-channel pack that holds the remaining (h % 4) channels.
    const size_t hBlockEnd = h & ~static_cast<size_t>(sparseBlockOC - 1);
    const size_t cRemainOffset = (h >> packCUnitLog) * cStride;

    // An assembly fast path exists but is disabled in the original source:
    // if (eSize == eP && (h % sparseBlockOC == 0)) {
    //     _AVX512_MNNPackedSparseMatMulEpx4_ASM(C, A, B, eSize, parameter, postParameters, bias, NNZMap, dataOffsetMap);
    //     return;
    // }

    // Clamp bounds, broadcast once at every vector width used below.
    const float minValue = *(postParameters + 2);
    const float maxValue = *(postParameters + 3);
    const __m512 vmin = _mm512_set1_ps(minValue);
    const __m512 vmax = _mm512_set1_ps(maxValue);
    const __m256 vmin256 = _mm512_extractf32x8_ps(vmin, 0);
    const __m256 vmax256 = _mm512_extractf32x8_ps(vmax, 0);
    const __m128 vmin128 = _mm512_extractf32x4_ps(vmin, 0);
    const __m128 vmax128 = _mm512_extractf32x4_ps(vmax, 0);

    const float* a = A;
    size_t ie = 0;

    // ---------------------------------------------------------------------
    // Full tiles: 48 columns (3 x __m512) per iteration.
    // ---------------------------------------------------------------------
    for (ie = 0; ie + eP <= eSize; ie += eP) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++; // leading delta
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            // vacc{0,1,2}: channel ih, vacc{3,4,5}: ih+1, vacc{6,7,8}: ih+2, vacc{9,10,11}: ih+3;
            // within each triple the vectors cover columns 0-15, 16-31, 32-47.
            __m512 vacc0 = bias ? _mm512_set1_ps(bias[ih]) : _mm512_setzero_ps();
            __m512 vacc3 = bias ? _mm512_set1_ps(bias[ih + 1]) : _mm512_setzero_ps();
            __m512 vacc6 = bias ? _mm512_set1_ps(bias[ih + 2]) : _mm512_setzero_ps();
            __m512 vacc9 = bias ? _mm512_set1_ps(bias[ih + 3]) : _mm512_setzero_ps();
            __m512 vacc1 = vacc0;
            __m512 vacc2 = vacc0;
            __m512 vacc4 = vacc3;
            __m512 vacc5 = vacc3;
            __m512 vacc7 = vacc6;
            __m512 vacc8 = vacc6;
            __m512 vacc10 = vacc9;
            __m512 vacc11 = vacc9;
            const unsigned int lElement = *nnz++;

            // Software pipelining: the A vectors for the next entry are loaded one
            // iteration ahead. This consumes one extra delta, which is undone after the loop.
            // NOTE: the original author reports that pre-loading the weight broadcasts in
            // the same way "donot work" and suggests trying a 2-way segmented iteration.
            __m512 va0_15_swap = _mm512_loadu_ps(a);
            __m512 va16_31_swap = _mm512_loadu_ps(a + 16);
            __m512 va32_47_swap = _mm512_loadu_ps(a + 32);
            a += *dataOffset++;

            for (unsigned int il = 0; il < lElement; ++il) {
                const __m512 va0_15 = va0_15_swap;
                const __m512 va16_31 = va16_31_swap;
                const __m512 va32_47 = va32_47_swap;

                const __m512 w0 = _mm512_set1_ps(*(w));
                const __m512 w1 = _mm512_set1_ps(*(w + 1));
                const __m512 w2 = _mm512_set1_ps(*(w + 2));
                const __m512 w3 = _mm512_set1_ps(*(w + 3));
                w += sparseBlockOC;

                vacc0 = _mm512_fmadd_ps(va0_15, w0, vacc0);
                vacc1 = _mm512_fmadd_ps(va16_31, w0, vacc1);
                vacc2 = _mm512_fmadd_ps(va32_47, w0, vacc2);
                va0_15_swap = _mm512_loadu_ps(a);
                va16_31_swap = _mm512_loadu_ps(a + 16);
                va32_47_swap = _mm512_loadu_ps(a + 32);
                vacc3 = _mm512_fmadd_ps(va0_15, w1, vacc3);
                vacc4 = _mm512_fmadd_ps(va16_31, w1, vacc4);
                vacc5 = _mm512_fmadd_ps(va32_47, w1, vacc5);

                a += *dataOffset++;

                vacc6 = _mm512_fmadd_ps(va0_15, w2, vacc6);
                vacc7 = _mm512_fmadd_ps(va16_31, w2, vacc7);
                vacc8 = _mm512_fmadd_ps(va32_47, w2, vacc8);
                vacc9 = _mm512_fmadd_ps(va0_15, w3, vacc9);
                vacc10 = _mm512_fmadd_ps(va16_31, w3, vacc10);
                vacc11 = _mm512_fmadd_ps(va32_47, w3, vacc11);
            }
            // Undo the look-ahead delta taken for the pre-load.
            --dataOffset;
            a -= *dataOffset;

            vacc0 = _mm512_min_ps(vacc0, vmax);
            vacc1 = _mm512_min_ps(vacc1, vmax);
            vacc2 = _mm512_min_ps(vacc2, vmax);
            vacc3 = _mm512_min_ps(vacc3, vmax);
            vacc4 = _mm512_min_ps(vacc4, vmax);
            vacc5 = _mm512_min_ps(vacc5, vmax);
            vacc6 = _mm512_min_ps(vacc6, vmax);
            vacc7 = _mm512_min_ps(vacc7, vmax);
            vacc8 = _mm512_min_ps(vacc8, vmax);
            vacc9 = _mm512_min_ps(vacc9, vmax);
            vacc10 = _mm512_min_ps(vacc10, vmax);
            vacc11 = _mm512_min_ps(vacc11, vmax);

            vacc0 = _mm512_max_ps(vacc0, vmin);
            vacc1 = _mm512_max_ps(vacc1, vmin);
            vacc2 = _mm512_max_ps(vacc2, vmin);
            vacc3 = _mm512_max_ps(vacc3, vmin);
            vacc4 = _mm512_max_ps(vacc4, vmin);
            vacc5 = _mm512_max_ps(vacc5, vmin);
            vacc6 = _mm512_max_ps(vacc6, vmin);
            vacc7 = _mm512_max_ps(vacc7, vmin);
            vacc8 = _mm512_max_ps(vacc8, vmin);
            vacc9 = _mm512_max_ps(vacc9, vmin);
            vacc10 = _mm512_max_ps(vacc10, vmin);
            vacc11 = _mm512_max_ps(vacc11, vmin);

            // Project macros (defined outside this function): store the 4 channel
            // vectors of each 16-column group, one 128-bit quarter per invocation.
            TRANSPOSE4x4_STORE(c, 0, 0, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 1, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 2, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 3, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 1, 0, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 1, 1, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 1, 2, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 1, 3, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 2, 0, packCUnit, vacc2, vacc5, vacc8, vacc11);
            TRANSPOSE4x4_STORE(c, 2, 1, packCUnit, vacc2, vacc5, vacc8, vacc11);
            TRANSPOSE4x4_STORE(c, 2, 2, packCUnit, vacc2, vacc5, vacc8, vacc11);
            TRANSPOSE4x4_STORE(c, 2, 3, packCUnit, vacc2, vacc5, vacc8, vacc11);
        }

        // Remaining (h % 4) channels, one channel and one weight per entry.
        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            __m512 vacc0 = nullptr != bias ? _mm512_set1_ps(*(bias + ih)) : _mm512_setzero_ps();
            __m512 vacc1 = vacc0;
            __m512 vacc2 = vacc0;
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                const __m512 va0_15 = _mm512_loadu_ps(a);
                const __m512 va16_31 = _mm512_loadu_ps(a + 16);
                const __m512 va32_47 = _mm512_loadu_ps(a + 32);
                const __m512 w0 = _mm512_set1_ps(*w);
                w++;
                a += *dataOffset++;
                vacc0 = _mm512_fmadd_ps(va0_15, w0, vacc0);
                vacc1 = _mm512_fmadd_ps(va16_31, w0, vacc1);
                vacc2 = _mm512_fmadd_ps(va32_47, w0, vacc2);
            }
            vacc0 = _mm512_min_ps(vacc0, vmax);
            vacc1 = _mm512_min_ps(vacc1, vmax);
            vacc2 = _mm512_min_ps(vacc2, vmax);
            vacc0 = _mm512_max_ps(vacc0, vmin);
            vacc1 = _mm512_max_ps(vacc1, vmin);
            vacc2 = _mm512_max_ps(vacc2, vmin);

            // how to store faster: st4 / transpose
            STORE_VECTOR_AS_COLUMN(c, 0, packCUnit, vacc0);
            STORE_VECTOR_AS_COLUMN(c, 1, packCUnit, vacc1);
            STORE_VECTOR_AS_COLUMN(c, 2, packCUnit, vacc2);
        }
        a += aStride;
    }

    const size_t tailSize = eSize % eP;

    // ---------------------------------------------------------------------
    // Tail of 32 columns (2 x __m512).
    // ---------------------------------------------------------------------
    if (tailSize & 0x20) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++;
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            __m512 vacc0 = bias ? _mm512_set1_ps(bias[ih]) : _mm512_setzero_ps();
            __m512 vacc3 = bias ? _mm512_set1_ps(bias[ih + 1]) : _mm512_setzero_ps();
            __m512 vacc6 = bias ? _mm512_set1_ps(bias[ih + 2]) : _mm512_setzero_ps();
            __m512 vacc9 = bias ? _mm512_set1_ps(bias[ih + 3]) : _mm512_setzero_ps();
            __m512 vacc1 = vacc0;
            __m512 vacc4 = vacc3;
            __m512 vacc7 = vacc6;
            __m512 vacc10 = vacc9;
            const unsigned int lElement = *nnz++;

            // Same one-entry look-ahead on A as in the full-tile loop.
            __m512 va0_15_swap = _mm512_loadu_ps(a);
            __m512 va16_31_swap = _mm512_loadu_ps(a + 16);
            a += *dataOffset++;

            for (unsigned int il = 0; il < lElement; ++il) {
                const __m512 va0_15 = va0_15_swap;
                const __m512 va16_31 = va16_31_swap;

                const __m512 w0 = _mm512_set1_ps(*(w));
                const __m512 w1 = _mm512_set1_ps(*(w + 1));
                const __m512 w2 = _mm512_set1_ps(*(w + 2));
                const __m512 w3 = _mm512_set1_ps(*(w + 3));
                w += sparseBlockOC;

                vacc0 = _mm512_fmadd_ps(va0_15, w0, vacc0);
                vacc1 = _mm512_fmadd_ps(va16_31, w0, vacc1);
                va0_15_swap = _mm512_loadu_ps(a);
                va16_31_swap = _mm512_loadu_ps(a + 16);
                vacc3 = _mm512_fmadd_ps(va0_15, w1, vacc3);
                vacc4 = _mm512_fmadd_ps(va16_31, w1, vacc4);

                a += *dataOffset++;

                vacc6 = _mm512_fmadd_ps(va0_15, w2, vacc6);
                vacc7 = _mm512_fmadd_ps(va16_31, w2, vacc7);
                vacc9 = _mm512_fmadd_ps(va0_15, w3, vacc9);
                vacc10 = _mm512_fmadd_ps(va16_31, w3, vacc10);
            }
            --dataOffset;
            a -= *dataOffset;

            vacc0 = _mm512_min_ps(vacc0, vmax);
            vacc1 = _mm512_min_ps(vacc1, vmax);
            vacc3 = _mm512_min_ps(vacc3, vmax);
            vacc4 = _mm512_min_ps(vacc4, vmax);
            vacc6 = _mm512_min_ps(vacc6, vmax);
            vacc7 = _mm512_min_ps(vacc7, vmax);
            vacc9 = _mm512_min_ps(vacc9, vmax);
            vacc10 = _mm512_min_ps(vacc10, vmax);

            vacc0 = _mm512_max_ps(vacc0, vmin);
            vacc1 = _mm512_max_ps(vacc1, vmin);
            vacc3 = _mm512_max_ps(vacc3, vmin);
            vacc4 = _mm512_max_ps(vacc4, vmin);
            vacc6 = _mm512_max_ps(vacc6, vmin);
            vacc7 = _mm512_max_ps(vacc7, vmin);
            vacc9 = _mm512_max_ps(vacc9, vmin);
            vacc10 = _mm512_max_ps(vacc10, vmin);

            TRANSPOSE4x4_STORE(c, 0, 0, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 1, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 2, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 3, packCUnit, vacc0, vacc3, vacc6, vacc9);

            TRANSPOSE4x4_STORE(c, 1, 0, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 1, 1, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 1, 2, packCUnit, vacc1, vacc4, vacc7, vacc10);
            TRANSPOSE4x4_STORE(c, 1, 3, packCUnit, vacc1, vacc4, vacc7, vacc10);
        }

        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            __m512 vacc0 = nullptr != bias ? _mm512_set1_ps(*(bias + ih)) : _mm512_setzero_ps();
            __m512 vacc1 = vacc0;
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                const __m512 va0_15 = _mm512_loadu_ps(a);
                const __m512 va16_31 = _mm512_loadu_ps(a + 16);
                const __m512 w0 = _mm512_set1_ps(*w);
                w++;
                a += *dataOffset++;
                vacc0 = _mm512_fmadd_ps(va0_15, w0, vacc0);
                vacc1 = _mm512_fmadd_ps(va16_31, w0, vacc1);
            }
            vacc0 = _mm512_min_ps(vacc0, vmax);
            vacc1 = _mm512_min_ps(vacc1, vmax);
            vacc0 = _mm512_max_ps(vacc0, vmin);
            vacc1 = _mm512_max_ps(vacc1, vmin);

            // how to store faster: st4 / transpose
            STORE_VECTOR_AS_COLUMN(c, 0, packCUnit, vacc0);
            STORE_VECTOR_AS_COLUMN(c, 1, packCUnit, vacc1);
        }
        ie += 32;
        a += 32;
    }

    // ---------------------------------------------------------------------
    // Tail of 16 columns (1 x __m512).
    // ---------------------------------------------------------------------
    if (tailSize & 0x10) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++;
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            __m512 vacc0 = bias ? _mm512_set1_ps(bias[ih]) : _mm512_setzero_ps();
            __m512 vacc3 = bias ? _mm512_set1_ps(bias[ih + 1]) : _mm512_setzero_ps();
            __m512 vacc6 = bias ? _mm512_set1_ps(bias[ih + 2]) : _mm512_setzero_ps();
            __m512 vacc9 = bias ? _mm512_set1_ps(bias[ih + 3]) : _mm512_setzero_ps();
            const unsigned int lElement = *nnz++;

            __m512 va0_15_swap = _mm512_loadu_ps(a);
            a += *dataOffset++;

            for (unsigned int il = 0; il < lElement; ++il) {
                const __m512 va0_15 = va0_15_swap;

                const __m512 w0 = _mm512_set1_ps(*(w));
                const __m512 w1 = _mm512_set1_ps(*(w + 1));
                const __m512 w2 = _mm512_set1_ps(*(w + 2));
                const __m512 w3 = _mm512_set1_ps(*(w + 3));
                w += sparseBlockOC;

                vacc0 = _mm512_fmadd_ps(va0_15, w0, vacc0);
                va0_15_swap = _mm512_loadu_ps(a);
                a += *dataOffset++;
                vacc3 = _mm512_fmadd_ps(va0_15, w1, vacc3);
                vacc6 = _mm512_fmadd_ps(va0_15, w2, vacc6);
                vacc9 = _mm512_fmadd_ps(va0_15, w3, vacc9);
            }
            --dataOffset;
            a -= *dataOffset;

            vacc0 = _mm512_min_ps(vacc0, vmax);
            vacc3 = _mm512_min_ps(vacc3, vmax);
            vacc6 = _mm512_min_ps(vacc6, vmax);
            vacc9 = _mm512_min_ps(vacc9, vmax);

            vacc0 = _mm512_max_ps(vacc0, vmin);
            vacc3 = _mm512_max_ps(vacc3, vmin);
            vacc6 = _mm512_max_ps(vacc6, vmin);
            vacc9 = _mm512_max_ps(vacc9, vmin);

            TRANSPOSE4x4_STORE(c, 0, 0, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 1, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 2, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE4x4_STORE(c, 0, 3, packCUnit, vacc0, vacc3, vacc6, vacc9);
        }

        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            __m512 vacc0 = nullptr != bias ? _mm512_set1_ps(*(bias + ih)) : _mm512_setzero_ps();
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                const __m512 va0_15 = _mm512_loadu_ps(a);
                const __m512 w0 = _mm512_set1_ps(*w);
                w++;
                a += *dataOffset++;
                vacc0 = _mm512_fmadd_ps(va0_15, w0, vacc0);
            }
            vacc0 = _mm512_min_ps(vacc0, vmax);
            vacc0 = _mm512_max_ps(vacc0, vmin);

            // how to store faster: st4 / transpose
            STORE_VECTOR_AS_COLUMN(c, 0, packCUnit, vacc0);
        }
        ie += 16;
        a += 16;
    }

    // ---------------------------------------------------------------------
    // Tail of 8 columns (1 x __m256).
    // ---------------------------------------------------------------------
    if (tailSize & 0x08) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++;
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            __m256 vacc0 = bias ? _mm256_set1_ps(bias[ih]) : _mm256_setzero_ps();
            __m256 vacc3 = bias ? _mm256_set1_ps(bias[ih + 1]) : _mm256_setzero_ps();
            __m256 vacc6 = bias ? _mm256_set1_ps(bias[ih + 2]) : _mm256_setzero_ps();
            __m256 vacc9 = bias ? _mm256_set1_ps(bias[ih + 3]) : _mm256_setzero_ps();
            const unsigned int lElement = *nnz++;

            __m256 va0_7_swap = _mm256_loadu_ps(a);
            a += *dataOffset++;

            for (unsigned int il = 0; il < lElement; ++il) {
                const __m256 va0_7 = va0_7_swap;

                const __m256 w0 = _mm256_set1_ps(*(w));
                const __m256 w1 = _mm256_set1_ps(*(w + 1));
                const __m256 w2 = _mm256_set1_ps(*(w + 2));
                const __m256 w3 = _mm256_set1_ps(*(w + 3));
                w += sparseBlockOC;

                vacc0 = _mm256_fmadd_ps(va0_7, w0, vacc0);
                va0_7_swap = _mm256_loadu_ps(a);
                a += *dataOffset++;
                vacc3 = _mm256_fmadd_ps(va0_7, w1, vacc3);
                vacc6 = _mm256_fmadd_ps(va0_7, w2, vacc6);
                vacc9 = _mm256_fmadd_ps(va0_7, w3, vacc9);
            }
            --dataOffset;
            a -= *dataOffset;

            vacc0 = _mm256_min_ps(vacc0, vmax256);
            vacc3 = _mm256_min_ps(vacc3, vmax256);
            vacc6 = _mm256_min_ps(vacc6, vmax256);
            vacc9 = _mm256_min_ps(vacc9, vmax256);

            vacc0 = _mm256_max_ps(vacc0, vmin256);
            vacc3 = _mm256_max_ps(vacc3, vmin256);
            vacc6 = _mm256_max_ps(vacc6, vmin256);
            vacc9 = _mm256_max_ps(vacc9, vmin256);

            TRANSPOSE_M256_4x4_STORE(c, 0, packCUnit, vacc0, vacc3, vacc6, vacc9);
            TRANSPOSE_M256_4x4_STORE(c, 1, packCUnit, vacc0, vacc3, vacc6, vacc9);
        }

        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            __m256 vacc0 = nullptr != bias ? _mm256_set1_ps(*(bias + ih)) : _mm256_setzero_ps();
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                const __m256 va0_7 = _mm256_loadu_ps(a);
                const __m256 w0 = _mm256_set1_ps(*w);
                w++;
                a += *dataOffset++;
                vacc0 = _mm256_fmadd_ps(va0_7, w0, vacc0);
            }
            vacc0 = _mm256_min_ps(vacc0, vmax256);
            vacc0 = _mm256_max_ps(vacc0, vmin256);

            // how to store faster: st4 / transpose
            STORE_M256_VECTOR_AS_COLUMN(c, packCUnit, vacc0);
        }
        ie += 8;
        a += 8;
    }

    // ---------------------------------------------------------------------
    // Tail of 4 columns (1 x __m128).
    // ---------------------------------------------------------------------
    if (tailSize & 0x04) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++;
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            __m128 vacc0 = bias ? _mm_set1_ps(bias[ih]) : _mm_setzero_ps();
            __m128 vacc3 = bias ? _mm_set1_ps(bias[ih + 1]) : _mm_setzero_ps();
            __m128 vacc6 = bias ? _mm_set1_ps(bias[ih + 2]) : _mm_setzero_ps();
            __m128 vacc9 = bias ? _mm_set1_ps(bias[ih + 3]) : _mm_setzero_ps();
            const unsigned int lElement = *nnz++;

            __m128 va0_3_swap = _mm_loadu_ps(a);
            a += *dataOffset++;

            for (unsigned int il = 0; il < lElement; ++il) {
                const __m128 va0_3 = va0_3_swap;

                const __m128 w0 = _mm_set1_ps(*(w));
                const __m128 w1 = _mm_set1_ps(*(w + 1));
                const __m128 w2 = _mm_set1_ps(*(w + 2));
                const __m128 w3 = _mm_set1_ps(*(w + 3));
                w += sparseBlockOC;

                vacc0 = _mm_fmadd_ps(va0_3, w0, vacc0);
                va0_3_swap = _mm_loadu_ps(a);
                a += *dataOffset++;
                vacc3 = _mm_fmadd_ps(va0_3, w1, vacc3);
                vacc6 = _mm_fmadd_ps(va0_3, w2, vacc6);
                vacc9 = _mm_fmadd_ps(va0_3, w3, vacc9);
            }
            --dataOffset;
            a -= *dataOffset;

            vacc0 = _mm_min_ps(vacc0, vmax128);
            vacc3 = _mm_min_ps(vacc3, vmax128);
            vacc6 = _mm_min_ps(vacc6, vmax128);
            vacc9 = _mm_min_ps(vacc9, vmax128);

            vacc0 = _mm_max_ps(vacc0, vmin128);
            vacc3 = _mm_max_ps(vacc3, vmin128);
            vacc6 = _mm_max_ps(vacc6, vmin128);
            vacc9 = _mm_max_ps(vacc9, vmin128);

            // After the transpose each register holds the 4 channels of one column.
            _MM_TRANSPOSE4_PS(vacc0, vacc3, vacc6, vacc9);
            _mm_storeu_ps(c, vacc0);
            _mm_storeu_ps(c + packCUnit, vacc3);
            _mm_storeu_ps(c + packCUnit * 2, vacc6);
            _mm_storeu_ps(c + packCUnit * 3, vacc9);
        }

        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            __m128 vacc0 = nullptr != bias ? _mm_set1_ps(*(bias + ih)) : _mm_setzero_ps();
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                const __m128 va0_3 = _mm_loadu_ps(a);
                const __m128 w0 = _mm_set1_ps(*w);
                w++;
                a += *dataOffset++;
                vacc0 = _mm_fmadd_ps(va0_3, w0, vacc0);
            }
            vacc0 = _mm_min_ps(vacc0, vmax128);
            vacc0 = _mm_max_ps(vacc0, vmin128);

            // Spill to a plain array to scatter the 4 columns; this avoids reading
            // an inactive union member.
            float lanes[4];
            _mm_storeu_ps(lanes, vacc0);
            c[0] = lanes[0];
            c[packCUnit] = lanes[1];
            c[packCUnit * 2] = lanes[2];
            c[packCUnit * 3] = lanes[3];
        }
        ie += 4;
        a += 4;
    }

    // ---------------------------------------------------------------------
    // Tail of 2 columns. Roles are swapped here: the 4 weights of a group are
    // the vector and each A value is broadcast, so no transpose is needed.
    // ---------------------------------------------------------------------
    if (tailSize & 0x02) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++;
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            // vacc0: column 0, vacc1: column 1; lanes are channels ih..ih+3.
            __m128 vacc0 = bias ? _mm_loadu_ps(bias + ih) : _mm_setzero_ps();
            __m128 vacc1 = vacc0;
            const unsigned int lElement = *nnz++;

            __m128 va0_swap = _mm_set1_ps(*(a));
            __m128 va1_swap = _mm_set1_ps(*(a + 1));
            a += *dataOffset++;

            for (unsigned int il = 0; il < lElement; ++il) {
                const __m128 va0 = va0_swap;
                const __m128 va1 = va1_swap;

                const __m128 w0_4 = _mm_loadu_ps(w);
                w += sparseBlockOC;

                vacc0 = _mm_fmadd_ps(va0, w0_4, vacc0);
                va0_swap = _mm_set1_ps(*(a));
                va1_swap = _mm_set1_ps(*(a + 1));
                a += *dataOffset++;
                vacc1 = _mm_fmadd_ps(va1, w0_4, vacc1);
            }
            --dataOffset;
            a -= *dataOffset;

            vacc0 = _mm_min_ps(vacc0, vmax128);
            vacc1 = _mm_min_ps(vacc1, vmax128);

            vacc0 = _mm_max_ps(vacc0, vmin128);
            vacc1 = _mm_max_ps(vacc1, vmin128);

            // transpose is omitted
            _mm_storeu_ps(c, vacc0);
            _mm_storeu_ps(c + packCUnit, vacc1);
        }

        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            __m128 vacc0 = nullptr != bias ? _mm_set1_ps(*(bias + ih)) : _mm_setzero_ps();
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                // Loads 4 lanes although only lanes 0 and 1 are stored below.
                // NOTE(review): assumes A stays readable 2 floats past this tail — confirm.
                const __m128 va0_3 = _mm_loadu_ps(a);
                const __m128 w0 = _mm_set1_ps(*w);
                w++;
                a += *dataOffset++;
                vacc0 = _mm_fmadd_ps(va0_3, w0, vacc0);
            }
            vacc0 = _mm_min_ps(vacc0, vmax128);
            vacc0 = _mm_max_ps(vacc0, vmin128);

            float lanes[4];
            _mm_storeu_ps(lanes, vacc0);
            c[0] = lanes[0];
            c[packCUnit] = lanes[1];
        }
        ie += 2;
        a += 2;
    }

    // ---------------------------------------------------------------------
    // Tail of 1 column: broadcast A value times 4 weights, then scalar remainder.
    // ---------------------------------------------------------------------
    if (tailSize & 0x01) {
        const int* dataOffset = dataOffsetMap;
        a += *dataOffset++;
        const float* w = B;
        float* blockC = C + (ie << packCUnitLog);
        const unsigned int* nnz = NNZMap;

        size_t ih = 0;
        for (; ih < hBlockEnd; ih += sparseBlockOC) {
            auto c = blockC + (ih >> packCUnitLog) * cStride + (ih % packCUnit);

            __m128 vacc0 = bias ? _mm_loadu_ps(bias + ih) : _mm_setzero_ps();
            const unsigned int lElement = *nnz++;
            // No look-ahead in this path, so there is no delta to undo afterwards.
            for (unsigned int il = 0; il < lElement; ++il) {
                const __m128 va0 = _mm_set1_ps(*(a));
                const __m128 w0_4 = _mm_loadu_ps(w);
                w += sparseBlockOC;
                vacc0 = _mm_fmadd_ps(va0, w0_4, vacc0);
                a += *dataOffset++;
            }
            vacc0 = _mm_min_ps(vacc0, vmax128);
            vacc0 = _mm_max_ps(vacc0, vmin128);

            // transpose is omitted
            _mm_storeu_ps(c, vacc0);
        }

        blockC += cRemainOffset;
        for (; ih < h; ih++) {
            auto c = blockC + ih % packCUnit;
            float acc0 = nullptr != bias ? *(bias + ih) : 0.0f;
            const unsigned int lElement = *nnz++;
            for (unsigned int il = 0; il < lElement; ++il) {
                acc0 += (*a) * (*w);
                w++;
                a += *dataOffset++;
            }
            acc0 = std::max(std::min(maxValue, acc0), minValue);
            c[0] = acc0;
        }
    }

    return;
}