void _SSE_MNNGemmInt8AddBiasScale_16x4_w4()

in source/backend/cpu/x86_x64/sse/GemmInt8.cpp [349:634]


// SSE GEMM kernel: int8 source x packed low-bit weights -> float output, with per-block
// scale / bias handling and optional final bias + clamp.
//
// For every output group dz in [0, dst_depth_quad) and every block bk in [0, post->blockNum):
//   1. accumulate 32-bit integer dot products of four source vectors (s0..s3) with four weight
//      vectors (w0..w3, produced by LOAD_INT4_TO_INT8) over src_depth_quad steps,
//   2. convert to float and multiply by the block's scale (stored right after the block's weights),
//   3. add srcKernelSum * weightBias, and, when post->inputScale is set, apply the input scale and
//      the weightKernelSum-based correction,
//   4. add the partial result of the previous blocks (kept in post->accumBuffer),
//   5. on the last block add post->biasFloat (if any), clamp with post->fp32minmax (if any) and
//      store realDst vectors of 4 floats to dst; otherwise store them back to post->accumBuffer.
//
// Parameters:
//   dst            - output buffer; declared int8_t* but written as float (see the stores below).
//   src            - int8 source; each inner step issues four 128-bit loads spaced
//                    GEMM_INT8_SRC_UNIT bytes apart.
//   weight         - packed weights; each (dz, bk) block occupies weight_step_Z bytes: the weights
//                    for src_depth_quad steps followed by a float scale array and a float bias array.
//   src_depth_quad - number of inner accumulation steps per block.
//   dst_step       - stride between consecutive dz outputs (divided by sizeof(int8_t) before use).
//   dst_depth_quad - number of output groups.
//   post           - post-treatment parameters; fields read here: useInt8, minValue, maxValue,
//                    fp32minmax, biasFloat, accumBuffer, blockNum, srcKernelSum, inputScale,
//                    inputBias, weightKernelSum.
//   realDst        - number of source vectors whose results are accumulated and stored
//                    (compared against GEMM_INT8_DST_XUNIT below).
//
// NOTE(review): LOAD_INT4_TO_INT8 is defined outside this block; it presumably unpacks two 4-bit
// weights per byte into w0..w3 using `weight_sz` and `mask` (consistent with the 0.5 factor in the
// weight strides) - confirm against the macro definition before renaming any local.
void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                            size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    // This kernel only handles the non-int8-output case.
    MNN_ASSERT(post->useInt8 == 0);
    // sizeof(int8_t) is 1, so this is the dz stride expressed in int8_t elements of dst.
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    // All-zero vector used by COMPUTE to widen 8-bit lanes to 16-bit.
    __m128i zero = _mm_set1_epi32(0);
    // minValue / maxValue are not referenced by any code visible in this function.
    __m128 minValue = _mm_set1_ps(post->minValue);
    __m128 maxValue = _mm_set1_ps(post->maxValue);
    // Optional float clamp bounds; only initialized (and only read) when post->fp32minmax is set.
    __m128 fp32min, fp32max;
    if (post->fp32minmax) {
        fp32min = _mm_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm_set1_ps((post->fp32minmax)[1]);
    }
    // Optional float bias added once, on the last block.
    const float* biasPtr = nullptr;
    if (post->biasFloat) {
        biasPtr = post->biasFloat;
    }
    // Scratch buffer carrying partial float sums from one block to the next.
    auto accumbuff = post->accumBuffer;
    auto blockNum = post->blockNum;
    // Size of one (dz, bk) weight block: half a byte per weight element for src_depth_quad steps,
    // plus 4 * 2 * GEMM_INT8_UNIT bytes (presumably two float arrays - scale and bias - of
    // GEMM_INT8_UNIT entries each, matching how scale_dz / weightBias_dz are derived below).
    int weight_step_Z = 0.5 * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) + 4 * 2 * GEMM_INT8_UNIT;
    // Size of the weights consumed by one inner (sz) step.
    int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);

    // oneValue and offset are only mentioned in the commented-out COMPUTE variant / not at all in
    // the visible code; they may still be referenced by macros defined elsewhere.
    auto oneValue = _mm_set1_epi16(1);
    auto offset = _mm_set1_epi32(128);
    // -128 as float; used for the bias correction when the input is not block-quantized.
    auto neg128f   = _mm_set1_ps(-128.f);
    auto srcKernelSumPtr = post->srcKernelSum;
    __m128 kernelSum0 = _mm_setzero_ps();
    __m128 kernelSum1 = _mm_setzero_ps();
    __m128 kernelSum2 = _mm_setzero_ps();
    __m128 kernelSum3 = _mm_setzero_ps();
    __m128 extrabias0 = _mm_setzero_ps();
    __m128 extrabias1 = _mm_setzero_ps();
    __m128 extrabias2 = _mm_setzero_ps();
    __m128 extrabias3 = _mm_setzero_ps();
    // Low-nibble byte mask; not used by visible code, presumably consumed by LOAD_INT4_TO_INT8.
    const auto mask = _mm_set1_epi8(0xf);
    // Initial broadcast of the per-source kernel sums. In the partial case only the first
    // min(realDst, 3) entries are read and the remaining vectors stay zero. The entries below
    // realDst are broadcast again for every block inside the main loop.
    if (GEMM_INT8_DST_XUNIT == realDst) {
        kernelSum0 = _mm_load_ps1(post->srcKernelSum);
        kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1);
        kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2);
        kernelSum3 = _mm_load_ps1(post->srcKernelSum + 3);
    } else {
        kernelSum0 = _mm_load_ps1(post->srcKernelSum);
        if (realDst > 1) {
            kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1);
        }
        if (realDst > 2) {
            kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2);
        }
    }
    auto f128 = _mm_set1_ps(128.f);
    // Per-source input scales (broadcast). They are refreshed per block inside the main loop only
    // when post->inputBias is also set; otherwise these initial values are used for every block.
    __m128 extrascale0 = _mm_setzero_ps();
    __m128 extrascale1 = _mm_setzero_ps();
    __m128 extrascale2 = _mm_setzero_ps();
    __m128 extrascale3 = _mm_setzero_ps();
    if (post->inputScale) {
        if (GEMM_INT8_DST_XUNIT == realDst) {
            extrascale0 = _mm_load_ps1(post->inputScale);
            extrascale1 = _mm_load_ps1(post->inputScale + 1);
            extrascale2 = _mm_load_ps1(post->inputScale + 2);
            extrascale3 = _mm_load_ps1(post->inputScale + 3);
        } else {
            extrascale0 = _mm_load_ps1(post->inputScale);
            if (realDst > 1) {
                extrascale1 = _mm_load_ps1(post->inputScale + 1);
            }
            if (realDst > 2) {
                extrascale2 = _mm_load_ps1(post->inputScale + 2);
            }
        }
    }
    // bias0..bias3 are assigned inside the main loop on every path that later reads them.
    __m128 bias0, bias1, bias2, bias3;
    // Per-source input biases (broadcast); refreshed per block inside the main loop.
    if (post->inputBias) {
        if (GEMM_INT8_DST_XUNIT == realDst) {
            extrabias0 = _mm_load_ps1(post->inputBias);
            extrabias1 = _mm_load_ps1(post->inputBias + 1);
            extrabias2 = _mm_load_ps1(post->inputBias + 2);
            extrabias3 = _mm_load_ps1(post->inputBias + 3);
        } else {
            extrabias0 = _mm_load_ps1(post->inputBias);
            if (realDst > 1) {
                extrabias1 = _mm_load_ps1(post->inputBias + 1);
            }
            if (realDst > 2) {
                extrabias2 = _mm_load_ps1(post->inputBias + 2);
            }
        }
    }
    // Outer loop: one iteration per output group.
    for (int dz = 0; dz < dst_depth_quad; ++dz) {
        auto dst_x           = dst + dz * dst_step_tmp;
        // The same accumulation buffer is reused for every dz; it is written on non-final blocks
        // and read back on the following block of the same dz.
        auto accum_x         = accumbuff;
        for (int bk = 0; bk < blockNum; ++bk) {
            // block's weight&scale&bias
            const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk *  weight_step_Z;
            // The float scale array starts right after the block's packed weights,
            // and the weight-bias array follows GEMM_INT8_UNIT floats later.
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT;
            // block's input
            const auto src_x = src + bk * src_depth_quad * GEMM_INT8_SRC_UNIT * realDst;
            // 32-bit accumulators, one per (source, weight) pair:
            // d* -> source 0, e* -> source 1, D* -> source 2, E* -> source 3; suffix = weight index.
            __m128i d0 = _mm_set1_epi32(0);
            __m128i d1 = _mm_set1_epi32(0);
            __m128i d2 = _mm_set1_epi32(0);
            __m128i d3 = _mm_set1_epi32(0);

            __m128i e0 = _mm_set1_epi32(0);
            __m128i e1 = _mm_set1_epi32(0);
            __m128i e2 = _mm_set1_epi32(0);
            __m128i e3 = _mm_set1_epi32(0);

            __m128i D0 = _mm_set1_epi32(0);
            __m128i D1 = _mm_set1_epi32(0);
            __m128i D2 = _mm_set1_epi32(0);
            __m128i D3 = _mm_set1_epi32(0);

            __m128i E0 = _mm_set1_epi32(0);
            __m128i E1 = _mm_set1_epi32(0);
            __m128i E2 = _mm_set1_epi32(0);
            __m128i E3 = _mm_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + weight_step_Y * sz;
                const auto src_z     = src_x + sz * realDst * GEMM_INT8_SRC_UNIT;

                // Macro defined outside this block; the code below relies on it to provide w0..w3.
                LOAD_INT4_TO_INT8;

                // NOTE(review): all four source vectors are loaded regardless of realDst, while the
                // step stride is realDst * GEMM_INT8_SRC_UNIT - presumably the caller guarantees the
                // extra bytes are readable; confirm.
                auto s0 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 0));
                auto s1 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 1));
                auto s2 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 2));
                auto s3 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 3));


    //#define COMPUTE(i, j)\
    //auto d##i##j = _mm_maddubs_epi16(s##i, w##j);\
    //d##i##j = _mm_madd_epi16(d##i##j, oneValue);\

    // COMPUTE(i, j): 16-byte dot-product step between source vector s<i> and weight vector w<j>.
    //  - W: weight bytes are placed in the high byte of each 16-bit lane and arithmetically
    //       shifted right by 8, i.e. sign-extended to int16.
    //  - S: source bytes are placed in the low byte with a zero high byte, i.e. zero-extended.
    //  - d<i><j>: _mm_madd_epi16 multiplies the 16-bit lanes and sums adjacent pairs to 32 bits;
    //       the low-half and high-half results are added, leaving four 32-bit partial sums.
    #define COMPUTE(i, j)\
    auto W##i##j##0 = _mm_srai_epi16(_mm_unpacklo_epi8(zero, w##j), 8);\
    auto W##i##j##1 = _mm_srai_epi16(_mm_unpackhi_epi8(zero, w##j), 8);\
    auto S##i##j##0 = _mm_unpacklo_epi8(s##i, zero);\
    auto S##i##j##1 = _mm_unpackhi_epi8(s##i, zero);\
    auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_epi16(S##i##j##1, W##i##j##1));\

                COMPUTE(0, 0);
                COMPUTE(0, 1);
                COMPUTE(0, 2);
                COMPUTE(0, 3);
                COMPUTE(1, 0);
                COMPUTE(1, 1);
                COMPUTE(1, 2);
                COMPUTE(1, 3);
                COMPUTE(2, 0);
                COMPUTE(2, 1);
                COMPUTE(2, 2);
                COMPUTE(2, 3);
                COMPUTE(3, 0);
                COMPUTE(3, 1);
                COMPUTE(3, 2);
                COMPUTE(3, 3);

                // Fold this step's partial sums into the running accumulators.
                d0 = _mm_add_epi32(d0, d00);
                d1 = _mm_add_epi32(d1, d01);
                d2 = _mm_add_epi32(d2, d02);
                d3 = _mm_add_epi32(d3, d03);

                e0 = _mm_add_epi32(e0, d10);
                e1 = _mm_add_epi32(e1, d11);
                e2 = _mm_add_epi32(e2, d12);
                e3 = _mm_add_epi32(e3, d13);

                D0 = _mm_add_epi32(D0, d20);
                D1 = _mm_add_epi32(D1, d21);
                D2 = _mm_add_epi32(D2, d22);
                D3 = _mm_add_epi32(D3, d23);

                E0 = _mm_add_epi32(E0, d30);
                E1 = _mm_add_epi32(E1, d31);
                E2 = _mm_add_epi32(E2, d32);
                E3 = _mm_add_epi32(E3, d33);
            }
            // Horizontal reduction: after two rounds of hadd, lane j of the result holds the full
            // sum of accumulator j. d0..d3 end up holding the four outputs for sources 0..3.
            d0 = _mm_hadd_epi32(d0, d1);
            d1 = _mm_hadd_epi32(d2, d3);
            d0 = _mm_hadd_epi32(d0, d1);

            e0 = _mm_hadd_epi32(e0, e1);
            e1 = _mm_hadd_epi32(e2, e3);
            d1 = _mm_hadd_epi32(e0, e1);

            D0 = _mm_hadd_epi32(D0, D1);
            D1 = _mm_hadd_epi32(D2, D3);
            d2 = _mm_hadd_epi32(D0, D1);

            E0 = _mm_hadd_epi32(E0, E1);
            E1 = _mm_hadd_epi32(E2, E3);
            d3 = _mm_hadd_epi32(E0, E1);
            // Per-output scale and weight bias of this block (4 floats each).
            auto scaleValue = _mm_loadu_ps(scale_dz);
            auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz);
            __m128 f0 = _mm_cvtepi32_ps(d0);
            __m128 f1 = _mm_cvtepi32_ps(d1);
            __m128 f2 = _mm_cvtepi32_ps(d2);
            __m128 f3 = _mm_cvtepi32_ps(d3);

            // Per-block source kernel sums: realDst entries per block; vectors at index >= realDst
            // keep whatever value they had before.
            kernelSum0 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[0]);
            if (realDst > 1) kernelSum1 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[1]);
            if (realDst > 2) kernelSum2 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[2]);
            if (realDst > 3) kernelSum3 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[3]);
            // Source-sum times weight-bias term, added to the scaled dot product further below.
            auto xy0_0 = _mm_mul_ps(kernelSum0, weightBiasValue); // x dimension first
            auto xy0_1 = _mm_mul_ps(kernelSum1, weightBiasValue); // ..second
            auto xy0_2 = _mm_mul_ps(kernelSum2, weightBiasValue); // .. third
            auto xy0_3 = _mm_mul_ps(kernelSum3, weightBiasValue); // ..fourth
            f0 = _mm_mul_ps(f0, scaleValue);
            f1 = _mm_mul_ps(f1, scaleValue);
            f2 = _mm_mul_ps(f2, scaleValue);
            f3 = _mm_mul_ps(f3, scaleValue);
            if (post->inputScale) {
                // Input scales are indexed per block only when an input bias is present.
                if (post->inputBias) {
                    extrascale0 = _mm_set1_ps((post->inputScale + bk * realDst)[0]);
                    if (realDst > 1) extrascale1 = _mm_set1_ps((post->inputScale + bk * realDst)[1]);
                    if (realDst > 2) extrascale2 = _mm_set1_ps((post->inputScale + bk * realDst)[2]);
                    if (realDst > 3) extrascale3 = _mm_set1_ps((post->inputScale + bk * realDst)[3]);
                }
                f0 = _mm_mul_ps(f0, extrascale0);
                f1 = _mm_mul_ps(f1, extrascale1);
                f2 = _mm_mul_ps(f2, extrascale2);
                f3 = _mm_mul_ps(f3, extrascale3);
                // Weight-kernel-sum correction: applied on every block when an input bias exists,
                // otherwise only once, on the last block.
                if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                    if (post->inputBias) {
                        // Per-(dz, bk) weight sums multiplied by the per-block input bias.
                        auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMM_INT8_UNIT) + bk * GEMM_INT8_UNIT;
                        auto wsum = _mm_loadu_ps(wsumDz);
                        extrabias0 = _mm_set1_ps((post->inputBias + bk * realDst)[0]);
                        if (realDst > 1) extrabias1 = _mm_set1_ps((post->inputBias + bk * realDst)[1]);
                        if (realDst > 2) extrabias2 = _mm_set1_ps((post->inputBias + bk * realDst)[2]);
                        if (realDst > 3) extrabias3 = _mm_set1_ps((post->inputBias + bk * realDst)[3]);
                        bias0 = _mm_mul_ps(extrabias0, wsum);
                        bias1 = _mm_mul_ps(extrabias1, wsum);
                        bias2 = _mm_mul_ps(extrabias2, wsum);
                        bias3 = _mm_mul_ps(extrabias3, wsum);
                    } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                        // Per-dz weight sums multiplied by (-128 * input scale).
                        auto wsumDz = post->weightKernelSum + dz * GEMM_INT8_UNIT;
                        auto wsum = _mm_loadu_ps(wsumDz);
                        bias0 = _mm_mul_ps(_mm_mul_ps(extrascale0, neg128f), wsum);
                        bias1 = _mm_mul_ps(_mm_mul_ps(extrascale1, neg128f), wsum);
                        bias2 = _mm_mul_ps(_mm_mul_ps(extrascale2, neg128f), wsum);
                        bias3 = _mm_mul_ps(_mm_mul_ps(extrascale3, neg128f), wsum);
                    }
                    f0 = _mm_add_ps(f0, bias0);
                    f1 = _mm_add_ps(f1, bias1);
                    f2 = _mm_add_ps(f2, bias2);
                    f3 = _mm_add_ps(f3, bias3);
                }
            }
            f0 = _mm_add_ps(f0, xy0_0);
            f1 = _mm_add_ps(f1, xy0_1);
            f2 = _mm_add_ps(f2, xy0_2);
            f3 = _mm_add_ps(f3, xy0_3);

            // Gather the four results so the tail can be processed with loops bounded by realDst.
            __m128 f[4] = {f0, f1, f2, f3};
            
            // From the second block on, add the partial sums stored by the previous block.
            if (bk > 0) {
                for (int j = 0; j < realDst; ++j) {
                    auto dstv = _mm_loadu_ps(((float*)accum_x) + j * 4);
                    f[j] = _mm_add_ps(dstv, f[j]);
                }
            }

            if (bk == blockNum - 1) {
                // Last block: finalize (bias, clamp) and write the float results to dst.
                if (nullptr != biasPtr) {
                    const auto bias_dz   = biasPtr + dz * GEMM_INT8_UNIT;
                    auto biasValue = _mm_loadu_ps(bias_dz);
                    for (int j = 0; j < realDst; ++j) {
                        f[j] = _mm_add_ps(biasValue, f[j]);
                    }
                }
                if (post->fp32minmax) {
                    // Clamp to [fp32min, fp32max]: upper bound first, then lower bound.
                    for (int j = 0; j < realDst; ++j) {
                        f[j] = _mm_min_ps(f[j], fp32max);
                        f[j] = _mm_max_ps(f[j], fp32min);
                    }
                }
                for (int j = 0; j < realDst; ++j) {
                    _mm_storeu_ps(((float*)dst_x) + j * 4, f[j]);
                }
            } else {
                // Intermediate block: keep the running sums in the accumulation buffer.
                for (int j = 0; j < realDst; ++j) {
                    _mm_storeu_ps(((float*)accum_x) + j * 4, f[j]);
                }
            }
        }
    }
}