void _AVX_MNNGemmInt8AddBiasScale_16x4_w4()

in source/backend/cpu/x86_x64/avx/GemmInt8.cpp [59:558]
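AVX2 GEMM kernel for int8 inputs with 4-bit-packed weights: each weight byte holds two int4 values that are expanded to int8 on the fly, inputs are uint8 (int8 offset by +128), and quantization is blockwise, with optional per-block input scale/bias. The four branches below are the same pipeline unrolled for realDst = 4, 3, 2 and 1 destination columns; partial block results are staged in post->accumBuffer, and the bias and fp32 clamp are applied when the last block is reached.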


void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    MNN_ASSERT(post->useInt8 == 0); // this w4 kernel only writes float output
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
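    // zero128/minValue/maxValue/offset belong to the int8-output post-treatment and
    // appear to be unused in this float-only path (useInt8 == 0 is asserted above).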
    auto zero128 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_ps(post->minValue);
    auto maxValue = _mm256_set1_ps(post->maxValue);
    auto offset = _mm256_set1_epi32(128);
    __m256 fp32min, fp32max;
    if (post->fp32minmax) {
        fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
    }
    const float* biasPtr = nullptr;
    int inputBlockNum = 1;
    if (post->biasFloat) {
        biasPtr = post->biasFloat;
    }
    auto accumbuff = post->accumBuffer;
    auto blockNum = post->blockNum;
    if (post->inputBias) {
        inputBlockNum = blockNum;
    }

    int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) / 2;
    int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float)* GEMMINT8_AVX2_H;
    const __m128i mask = _mm_set1_epi8(0xf);
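    // LOAD_INT4_TO_INT8 (a macro defined elsewhere in this file) is expected to expand
    // the 32 packed int4 weights at weight_sz into two __m128i of int8, w_0 and w_1,
    // using `mask` to split the low and high nibbles of each byte.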
    
    auto srcKernelSumPtr = post->srcKernelSum; // unused; the loops read post->srcKernelSum directly
    __m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3;
    auto neg128_f   = _mm256_set1_ps(-128.f);
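    // Inputs are stored as uint8 (int8 + 128) so _mm256_maddubs_epi16 can be used;
    // neg128_f later removes the resulting 128 * weightKernelSum term from the output.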
    __m256 extrascale0 = _mm256_setzero_ps();
    __m256 extrascale1 = _mm256_setzero_ps();
    __m256 extrascale2 = _mm256_setzero_ps();
    __m256 extrascale3 = _mm256_setzero_ps();
    __m256 extrabias0 = _mm256_setzero_ps();
    __m256 extrabias1 = _mm256_setzero_ps();
    __m256 extrabias2 = _mm256_setzero_ps();
    __m256 extrabias3 = _mm256_setzero_ps();
    if (post->inputScale) {
        if (GEMMINT8_AVX2_E == realDst) {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            extrascale3 = _mm256_set1_ps(post->inputScale[3]);
        } else {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            if (realDst > 1) {
                extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            }
            if (realDst > 2) {
                extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            }
        }
    }
    auto oneValue = _mm256_set1_epi16(1);
    __m256 bias0, bias1, bias2, bias3;
    // weight&scale&bias: [oc/hp, blocknum, weight_step_Z]
    // weight_step_Z: [(kx*ky), ic/lp/blocknum, hp, lp] + [hp] + [hp]
    // input: [blocknum, blockLu, EP, LP]
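    // Per block, each fp32 partial is reconstructed roughly as
    //   f = (sum_k u8src_k * i4weight_k) * weightScale [* inputScale]
    //       + srcKernelSum * weightBias                    (weight zero-point term)
    //       [+ inputBias * weightKernelSum]                (block-quantized input)
    //       [- 128 * inputScale * weightKernelSum]         (last block, input not block-quantized)
    // with biasFloat and the fp32 min/max clamp applied once on the last block.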
    if (GEMMINT8_AVX2_E == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x         = accumbuff;

            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
                kernelSum3 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[3]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                __m256i D03 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
                    auto s3 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 3));

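                    // _mm256_maddubs_epi16 multiplies unsigned input bytes with signed
                    // weight bytes into pairwise 16-bit sums; madd against oneValue
                    // widens them to the 32-bit lanes accumulated in D00..D03.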
                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                    D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                    D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
                    D03 = _mm256_add_epi32(D03, _mm256_madd_epi16(_mm256_maddubs_epi16(s3, w0), oneValue));
                }
                auto D0 = D00;
                auto D1 = D01;
                auto D2 = D02;
                auto D3 = D03;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                auto f3 = _mm256_cvtepi32_ps(D3);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // .. second
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. third
                auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // .. fourth
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                f3 = _mm256_mul_ps(f3, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                        extrascale3 = _mm256_set1_ps((post->inputScale + bk * realDst)[3]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    f3 = _mm256_mul_ps(f3, extrascale3);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            extrabias3 = _mm256_set1_ps((post->inputBias + bk * realDst)[3]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                            bias3 = _mm256_mul_ps(extrabias3, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                            bias3 = _mm256_mul_ps(_mm256_mul_ps(extrascale3, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                        f3 = _mm256_add_ps(f3, bias3);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                f3 = _mm256_add_ps(f3, xy0_3);

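                // Partial results of earlier blocks are staged as fp32 in accumbuff
                // and folded back in from the second block onward.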
                if (bk > 0) {
                    auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                    auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                    auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                    auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
                    f0 = _mm256_add_ps(f0, dstv0);
                    f1 = _mm256_add_ps(f1, dstv1);
                    f2 = _mm256_add_ps(f2, dstv2);
                    f3 = _mm256_add_ps(f3, dstv3);
                }
                if (bk == blockNum - 1) {
                    if (biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                        f3 = _mm256_add_ps(f3, biasValue);
                    }
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f1 = _mm256_min_ps(f1, fp32max);
                        f2 = _mm256_min_ps(f2, fp32max);
                        f3 = _mm256_min_ps(f3, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                        f1 = _mm256_max_ps(f1, fp32min);
                        f2 = _mm256_max_ps(f2, fp32min);
                        f3 = _mm256_max_ps(f3, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                    _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
                } else {
                    _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                    _mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
                }
            }
        }
        return;
    }
    if (3 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x         = accumbuff;

            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;

                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));

                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                    D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                    D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
                }
                auto D0 = D00;
                auto D1 = D01;
                auto D2 = D02;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // .. second
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. third
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);

                if (bk > 0) {
                    auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                    auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                    auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                    f0 = _mm256_add_ps(f0, dstv0);
                    f1 = _mm256_add_ps(f1, dstv1);
                    f2 = _mm256_add_ps(f2, dstv2);
                }
                if (bk == blockNum - 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                    }
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f1 = _mm256_min_ps(f1, fp32max);
                        f2 = _mm256_min_ps(f2, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                        f1 = _mm256_max_ps(f1, fp32min);
                        f2 = _mm256_max_ps(f2, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                } else {
                    _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                }
            }
        }
        return;
    }
    if (2 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x         = accumbuff;
            
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);

                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));

                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                    D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                }
                auto D0 = D00;
                auto D1 = D01;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // .. second
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);

                if (bk > 0) {
                    auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                    auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                    f0         = _mm256_add_ps(f0, dstv0);
                    f1         = _mm256_add_ps(f1, dstv1);
                }
                if (bk == blockNum - 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                    }
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f1 = _mm256_min_ps(f1, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                        f1 = _mm256_max_ps(f1, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                } else {
                    _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                }
            }
        }
        return;
    }
    if (1 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x         = accumbuff;

            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                // source kernel sum
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);

                __m256i D00 = _mm256_set1_epi32(0);

                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));

                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                }
                auto D0 = D00;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                f0 = _mm256_mul_ps(f0, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);

                if (bk > 0) {
                    auto dstv = _mm256_loadu_ps(((float*)accum_x));
                    f0        = _mm256_add_ps(f0, dstv);
                }
                if (bk == 0) { // unlike the other branches, bias is added on the first block; it rides through accumbuff, so the final result is the same
                    if (biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                    }
                }
                if (bk == blockNum - 1) {
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x), f0);
                } else {
                    _mm256_storeu_ps(((float*)accum_x), f0);
                }
            }
        }
        return;
    }

}
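
A minimal scalar sketch of what one block contributes for a single output channel and destination column, under the layout notes above. All names below are hypothetical, buffers are flattened, the int4 weights are assumed already expanded to int8, and inputScale should be passed as 1.0f when post->inputScale is null; it mirrors the branch structure of the kernel, not MNN's actual reference code.

#include <cstdint>
#include <cstddef>

// One block's fp32 contribution for a single (output channel, column) pair.
// srcU8: the block's input bytes after the +128 offset; w: int4 weights as int8.
static float w4BlockRef(const uint8_t* srcU8, const int8_t* w, size_t k,
                        float weightScale, float weightBias, float srcKernelSum,
                        float inputScale, bool hasInputBias, float inputBias,
                        float weightKernelSum, bool lastBlock) {
    int32_t acc = 0;
    for (size_t i = 0; i < k; ++i) {
        acc += int32_t(srcU8[i]) * int32_t(w[i]); // u8 x s8 dot product (maddubs/madd path)
    }
    float f = float(acc) * weightScale * inputScale; // dequantize the raw accumulator
    if (hasInputBias) {
        f += inputBias * weightKernelSum;            // per-block input zero-point term
    } else if (lastBlock) {
        f += -128.f * inputScale * weightKernelSum;  // undo the +128 input offset once
    }
    f += srcKernelSum * weightBias;                  // weight zero-point term
    return f; // caller sums blocks, then adds biasFloat and clamps to fp32minmax
}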