void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit()

in source/backend/cpu/x86_x64/avx/GemmInt8.cpp [560:1142]
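AVX2 micro-kernel for MNN's quantized GEMM. For each output-channel pack dz it multiplies up to GEMMINT8_AVX2_E (= 4) destination columns of quantized input against packed int8 weights, accumulates in int32, then either requantizes back to int8 (post->useInt8 == 1, via POSTTREAT) or dequantizes to float, staging per-quant-block partial sums in post->accumBuffer and writing dst only on the last block. On the float path each destination column j effectively receives

    f_j = dot_j * weightScale [* inputScale_j] + srcKernelSum_j * weightQuantBias [+ zero-point corrections] + biasFloat

clamped to post->fp32minmax when that pointer is set.

COMPUTE, NORMAL_HADD and POSTTREAT are macros defined earlier in GemmInt8.cpp and are not part of this excerpt. As a rough sketch of their roles (an assumption about their shape, not the file's actual definitions): COMPUTE(u, v) accumulates the widened products of weight half W##u and broadcast source column S##v into the int32 accumulator D##u##v, roughly

    // Illustrative sketch only; the real macro lives earlier in GemmInt8.cpp and may differ.
    #define COMPUTE(u, v) \
        D##u##v = _mm256_add_epi32(D##u##v, _mm256_madd_epi16(W##u, S##v))

NORMAL_HADD(x, y) horizontally reduces the paired partial sums of the two 16-byte weight halves into one vector of eight per-channel int32 totals, and POSTTREAT(i) presumably rounds with the +/-0.5 constants, clamps to [minValue, maxValue], applies the 128 offset and stores column i as int8.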


void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t); // dst_step is in bytes; sizeof(int8_t) == 1, kept for symmetry with other element types
    auto zero128 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_ps(post->minValue);
    auto maxValue = _mm256_set1_ps(post->maxValue);
    auto plus = _mm256_set1_ps(0.5f);
    auto minus = _mm256_set1_ps(-0.5f);
    auto offset = _mm256_set1_epi32(128);
    __m256 fp32min, fp32max;
    if (0 == post->useInt8 && post->fp32minmax) {
        fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
    }
    const float* biasPtr = nullptr;
    if (post->biasFloat) {
        biasPtr = post->biasFloat;
    }
    int inputBlockNum = 1;
    auto accumbuff = post->accumBuffer;
    auto blockNum = post->blockNum;
    if (post->inputBias) {
        inputBlockNum = blockNum;
    }

    int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
    int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float) * GEMMINT8_AVX2_H;
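    // Packed weight layout per (dz, bk) block: src_depth_quad tiles of
    // GEMMINT8_AVX2_L x GEMMINT8_AVX2_H int8 weights, then GEMMINT8_AVX2_H float
    // dequant scales, then GEMMINT8_AVX2_H float weight quant-bias values
    // (the 2 * sizeof(float) * GEMMINT8_AVX2_H tail of weight_step_Z).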
    
    auto srcKernelSumPtr = post->srcKernelSum;
    __m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3;
    auto neg128_f   = _mm256_set1_ps(-128.f);
    __m256 extrascale0 = _mm256_setzero_ps();
    __m256 extrascale1 = _mm256_setzero_ps();
    __m256 extrascale2 = _mm256_setzero_ps();
    __m256 extrascale3 = _mm256_setzero_ps();
    __m256 extrabias0 = _mm256_setzero_ps();
    __m256 extrabias1 = _mm256_setzero_ps();
    __m256 extrabias2 = _mm256_setzero_ps();
    __m256 extrabias3 = _mm256_setzero_ps();
    if (post->inputScale) {
        if (GEMMINT8_AVX2_E == realDst) {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            extrascale3 = _mm256_set1_ps(post->inputScale[3]);
        } else {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            if (realDst > 1) {
                extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            }
            if (realDst > 2) {
                extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            }
        }
    }
    __m256 bias0, bias1, bias2, bias3;
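    // One specialization per count of valid destination columns (realDst = 4, 3, 2, 1);
    // all four share the same pipeline and differ only in how many columns they touch.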
    if (GEMMINT8_AVX2_E == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x       = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight, scale and bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
                kernelSum3 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[3]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                __m256i D03 = _mm256_set1_epi32(0);
                __m256i D10 = _mm256_set1_epi32(0);
                __m256i D11 = _mm256_set1_epi32(0);
                __m256i D12 = _mm256_set1_epi32(0);
                __m256i D13 = _mm256_set1_epi32(0);

                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
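                    // Widen to int16 for the COMPUTE dot products: weights are sign-extended
                    // (int8 -> int16), inputs zero-extended (uint8 -> int16); each source
                    // column is 4 bytes broadcast across the register via _mm_broadcast_ss.
                    // The +128 input-encoding offset is undone later by the
                    // neg128_f * weightKernelSum correction on the dequant path.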
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);

                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
                    auto s3 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 3));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    auto S1 = _mm256_cvtepu8_epi16(s1);
                    auto S2 = _mm256_cvtepu8_epi16(s2);
                    auto S3 = _mm256_cvtepu8_epi16(s3);

                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                    COMPUTE(0, 1);
                    COMPUTE(1, 1);
                    COMPUTE(0, 2);
                    COMPUTE(1, 2);
                    COMPUTE(0, 3);
                    COMPUTE(1, 3);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto D1 = NORMAL_HADD(D01, D11);
                auto D2 = NORMAL_HADD(D02, D12);
                auto D3 = NORMAL_HADD(D03, D13);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                auto f3 = _mm256_cvtepi32_ps(D3);
                // x_kernelSum * w_quantZero: per-column compensation for the weight zero point
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension, column 0
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // column 2
                auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // column 3
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                f3 = _mm256_mul_ps(f3, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                        extrascale3 = _mm256_set1_ps((post->inputScale + bk * realDst)[3]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    f3 = _mm256_mul_ps(f3, extrascale3);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            extrabias3 = _mm256_set1_ps((post->inputBias + bk * realDst)[3]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                            bias3 = _mm256_mul_ps(extrabias3, wsum);
                        } else if (bk == blockNum - 1) { // input not block-quantized: apply the -128 correction only once, on the last block
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                            bias3 = _mm256_mul_ps(_mm256_mul_ps(extrascale3, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                        f3 = _mm256_add_ps(f3, bias3);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                f3 = _mm256_add_ps(f3, xy0_3);
                if (post->useInt8 == 1) {
                    if (biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                        f3 = _mm256_add_ps(f3, biasValue);
                    }
                    POSTTREAT(0);
                    POSTTREAT(1);
                    POSTTREAT(2);
                    POSTTREAT(3);
                } else {
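                    // Float-output path: partial sums are staged in accumbuff across the
                    // quant blocks; only the last block adds the bias, applies the fp32
                    // clamp and stores to dst.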
                    if (bk > 0) {
                        auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                        auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                        auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                        auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
                        f0 = _mm256_add_ps(f0, dstv0);
                        f1 = _mm256_add_ps(f1, dstv1);
                        f2 = _mm256_add_ps(f2, dstv2);
                        f3 = _mm256_add_ps(f3, dstv3);
                    }
                    if (bk == blockNum - 1) {
                        if (biasPtr) {
                            const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue       = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                            f1 = _mm256_add_ps(f1, biasValue);
                            f2 = _mm256_add_ps(f2, biasValue);
                            f3 = _mm256_add_ps(f3, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f1 = _mm256_min_ps(f1, fp32max);
                            f2 = _mm256_min_ps(f2, fp32max);
                            f3 = _mm256_min_ps(f3, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                            f1 = _mm256_max_ps(f1, fp32min);
                            f2 = _mm256_max_ps(f2, fp32min);
                            f3 = _mm256_max_ps(f3, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                        _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                        _mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
                    }
                }
            }
        }
        return;
    }
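    // The realDst == 3 / 2 / 1 branches below repeat the pipeline above with fewer active columns.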
    if (3 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x       = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight, scale and bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);

                __m256i D10 = _mm256_set1_epi32(0);
                __m256i D11 = _mm256_set1_epi32(0);
                __m256i D12 = _mm256_set1_epi32(0);

                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);

                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    auto S1 = _mm256_cvtepu8_epi16(s1);
                    auto S2 = _mm256_cvtepu8_epi16(s2);

                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                    COMPUTE(0, 1);
                    COMPUTE(1, 1);
                    COMPUTE(0, 2);
                    COMPUTE(1, 2);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto D1 = NORMAL_HADD(D01, D11);
                auto D2 = NORMAL_HADD(D02, D12);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                // x_kernelSum * w_quantZero: per-column compensation for the weight zero point
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension, column 0
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // column 2
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                        } else if (bk == blockNum - 1) { // input not block-quantized: apply the -128 correction only once, on the last block
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                if (post->useInt8 == 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                    }
                    POSTTREAT(0);
                    POSTTREAT(1);
                    POSTTREAT(2);
                } else {
                    if (bk > 0) {
                        auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                        auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                        auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                        f0 = _mm256_add_ps(f0, dstv0);
                        f1 = _mm256_add_ps(f1, dstv1);
                        f2 = _mm256_add_ps(f2, dstv2);
                    }
                    if (bk == blockNum - 1) {
                        if (nullptr != biasPtr) {
                            const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue       = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                            f1 = _mm256_add_ps(f1, biasValue);
                            f2 = _mm256_add_ps(f2, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f1 = _mm256_min_ps(f1, fp32max);
                            f2 = _mm256_min_ps(f2, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                            f1 = _mm256_max_ps(f1, fp32min);
                            f2 = _mm256_max_ps(f2, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                    }
                }
            }
        }
        return;
    }
    if (2 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x       = accumbuff;
            
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight, scale and bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);

                __m256i D10 = _mm256_set1_epi32(0);
                __m256i D11 = _mm256_set1_epi32(0);

                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);

                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    auto S1 = _mm256_cvtepu8_epi16(s1);

                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                    COMPUTE(0, 1);
                    COMPUTE(1, 1);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto D1 = NORMAL_HADD(D01, D11);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                // x_kernelSum * w_quantZero: per-column compensation for the weight zero point
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension, column 0
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                        } else if (bk == blockNum - 1) { // input not block-quantized: apply the -128 correction only once, on the last block
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);

                if (post->useInt8 == 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                    }
                    POSTTREAT(0);
                    POSTTREAT(1);
                } else {
                    if (bk > 0) {
                        auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                        auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                        f0         = _mm256_add_ps(f0, dstv0);
                        f1         = _mm256_add_ps(f1, dstv1);
                    }
                    if (bk == blockNum - 1) {
                        if (nullptr != biasPtr) {
                            const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue       = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                            f1 = _mm256_add_ps(f1, biasValue);
                        } 
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f1 = _mm256_min_ps(f1, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                            f1 = _mm256_max_ps(f1, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                    }
                }
            }
        }
        return;
    }
    if (1 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x           = dst + dz * dst_step_tmp;
            auto accum_x       = accumbuff;

            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight, scale and bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                // source kernel sum
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);

                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D10 = _mm256_set1_epi32(0);

                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);

                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto S0 = _mm256_cvtepu8_epi16(s0);

                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

                auto f0 = _mm256_cvtepi32_ps(D0);
                // x_kernelSum * w_quantZero: per-column compensation for the weight zero point
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension, column 0
                f0 = _mm256_mul_ps(f0, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                        } else if (bk == blockNum - 1) { // input not block-quantized: apply the -128 correction only once, on the last block
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);

                if (post->useInt8 == 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue       = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                    }
                    POSTTREAT(0);
                } else {
                    if (bk > 0) {
                        auto dstv = _mm256_loadu_ps(((float*)accum_x));
                        f0        = _mm256_add_ps(f0, dstv);
                    }
                    if (bk == blockNum - 1) {
                        if (biasPtr) {
                            const auto bias_dz   = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue       = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x), f0);
                        
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) , f0);
                    }
                }
            }
        }
        return;
    }

}
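
For context, a caller invokes the kernel once per destination tile, with dst_step expressed in bytes (the kernel divides by sizeof(int8_t)). Below is a hypothetical, minimal float-output setup that fills only the QuanPostTreatParameters fields this excerpt actually reads; all buffer names (biasData, reluBounds, srcSums, scratch, dstPtr, srcPtr, packedWeight) are placeholders, not identifiers from MNN:

    #include <cstring>

    QuanPostTreatParameters post;
    std::memset(&post, 0, sizeof(post));  // leave every unused field null/zero
    post.useInt8      = 0;                // take the float-output path
    post.biasFloat    = biasData;         // GEMMINT8_AVX2_H floats per dz, or nullptr
    post.fp32minmax   = reluBounds;       // {min, max} pair, or nullptr to skip clamping
    post.srcKernelSum = srcSums;          // per-column input sums, realDst entries per block
    post.blockNum     = 1;                // single quant block: accumBuffer is never re-read
    post.accumBuffer  = scratch;          // staging area for multi-block accumulation
    _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(dstPtr, srcPtr, packedWeight,
                                           srcDepthQuad, dstStepInBytes,
                                           dstDepthQuad, &post, /*realDst=*/4);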