void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast()

in source/backend/cpu/x86_x64/avx/GemmInt8.cpp [1143:1425]
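Fast-path AVX2 int8 GEMM micro-kernel: for each output-channel quad it accumulates up to four packed int8 source columns against one weight block, converts the int32 sums to float, applies the per-channel scale, adds the srcKernelSum x weight-bias correction and a bias term, then either clamps the float result to [fp32minmax[0], fp32minmax[1]] (useInt8 == 0) or requantizes back to int8 through POSTTREAT.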


void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    auto zero128 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_ps(post->minValue);
    auto maxValue = _mm256_set1_ps(post->maxValue);
    auto plus = _mm256_set1_ps(0.5f);
    auto minus = _mm256_set1_ps(-0.5f);
    auto oneValue = _mm256_set1_epi16(1);
    auto offset = _mm256_set1_epi32(128);
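    // zero128/minValue/maxValue/plus/minus/offset are set up for the POSTTREAT(k)
    // requantization macro used below (defined elsewhere in this file); oneValue is the
    // multiplier that widens the int16 pair sums from _mm256_maddubs_epi16 into int32.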
    __m256 fp32min, fp32max;
    if (0 == post->useInt8) {
        fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
    }
    auto srcKernelSumPtr = post->srcKernelSum;
    __m256 kernelSum0 = _mm256_setzero_ps();
    __m256 kernelSum1 = _mm256_setzero_ps();
    __m256 kernelSum2 = _mm256_setzero_ps();
    __m256 kernelSum3 = _mm256_setzero_ps();
    if (GEMMINT8_AVX2_E == realDst) {
        kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]);
        kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]);
        kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
        kernelSum3 = _mm256_set1_ps(post->srcKernelSum[3]);
    } else {
        kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]);
        if (realDst > 1) {
            kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]);
        }
        if (realDst > 2) {
            kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
        }
    }
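    // Per-dz weight block layout: src_depth_quad packs of GEMMINT8_AVX2_L x GEMMINT8_AVX2_H
    // int8 weights, followed by GEMMINT8_AVX2_H float scales and GEMMINT8_AVX2_H float
    // weight biases (the trailing 4 * 2 * GEMMINT8_AVX2_H bytes of weight_step_Z).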
    int weight_step_Z = src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) + 4 * 2 * GEMMINT8_AVX2_H;
    int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
    if (GEMMINT8_AVX2_E == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z           = dst + dz * dst_step_tmp;
            const auto src_x   = src;
            auto dst_x         = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            __m256i D01 = _mm256_set1_epi32(0);
            __m256i D02 = _mm256_set1_epi32(0);
            __m256i D03 = _mm256_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
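                // Broadcast one 32-bit group of 4 int8 inputs per column, multiply it
                // (unsigned) against the signed weights with _mm256_maddubs_epi16 (adjacent
                // pairs summed to int16), then widen and accumulate to int32 via _mm256_madd_epi16.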
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);

                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
                auto s3 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 3));

                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
                D03 = _mm256_add_epi32(D03, _mm256_madd_epi16(_mm256_maddubs_epi16(s3, w0), oneValue));

            }
            auto D0 = D00;
            auto D1 = D01;
            auto D2 = D02;
            auto D3 = D03;

            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            auto f1 = _mm256_cvtepi32_ps(D1);
            auto f2 = _mm256_cvtepi32_ps(D2);
            auto f3 = _mm256_cvtepi32_ps(D3);
            // srcKernelSum x weight quant-zero correction, one vector per column
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // column 0
            auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
            auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // column 2
            auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // column 3
            f0 = _mm256_mul_ps(f0, scaleValue);
            f1 = _mm256_mul_ps(f1, scaleValue);
            f2 = _mm256_mul_ps(f2, scaleValue);
            f3 = _mm256_mul_ps(f3, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            f1 = _mm256_add_ps(f1, xy0_1);
            f2 = _mm256_add_ps(f2, xy0_2);
            f3 = _mm256_add_ps(f3, xy0_3);
            auto biasValue       = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            f1 = _mm256_add_ps(f1, biasValue);
            f2 = _mm256_add_ps(f2, biasValue);
            f3 = _mm256_add_ps(f3, biasValue);
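            // Output: clamp and store float results when useInt8 == 0, otherwise requantize
            // each column back to int8 with the POSTTREAT macro.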
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f1 = _mm256_min_ps(f1, fp32max);
                f2 = _mm256_min_ps(f2, fp32max);
                f3 = _mm256_min_ps(f3, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                f1 = _mm256_max_ps(f1, fp32min);
                f2 = _mm256_max_ps(f2, fp32min);
                f3 = _mm256_max_ps(f3, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
            } else {
                POSTTREAT(0);
                POSTTREAT(1);
                POSTTREAT(2);
                POSTTREAT(3);
            }
        }
        return;
    }
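    // The realDst == 3 / 2 / 1 branches below repeat the same pipeline with fewer columns.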
    if (3 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z           = dst + dz * dst_step_tmp;
            const auto src_x   = src;
            auto dst_x         = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            __m256i D01 = _mm256_set1_epi32(0);
            __m256i D02 = _mm256_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);

                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));

                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
            }
            auto D0 = D00;
            auto D1 = D01;
            auto D2 = D02;

            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            auto f1 = _mm256_cvtepi32_ps(D1);
            auto f2 = _mm256_cvtepi32_ps(D2);
            // srcKernelSum x weight quant-zero correction, one vector per column
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // column 0
            auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
            auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // column 2
            f0 = _mm256_mul_ps(f0, scaleValue);
            f1 = _mm256_mul_ps(f1, scaleValue);
            f2 = _mm256_mul_ps(f2, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            f1 = _mm256_add_ps(f1, xy0_1);
            f2 = _mm256_add_ps(f2, xy0_2);
            auto biasValue       = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            f1 = _mm256_add_ps(f1, biasValue);
            f2 = _mm256_add_ps(f2, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f1 = _mm256_min_ps(f1, fp32max);
                f2 = _mm256_min_ps(f2, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                f1 = _mm256_max_ps(f1, fp32min);
                f2 = _mm256_max_ps(f2, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
            } else {
                POSTTREAT(0);
                POSTTREAT(1);
                POSTTREAT(2);
            }
        }
        return;
    }    
    if (2 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z           = dst + dz * dst_step_tmp;
            const auto src_x   = src;
            auto dst_x         = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            __m256i D01 = _mm256_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);

                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));

                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
            }
            auto D0 = D00;
            auto D1 = D01;

            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            auto f1 = _mm256_cvtepi32_ps(D1);
            // srcKernelSum x weight quant-zero correction, one vector per column
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // column 0
            auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
            f0 = _mm256_mul_ps(f0, scaleValue);
            f1 = _mm256_mul_ps(f1, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            f1 = _mm256_add_ps(f1, xy0_1);
            auto biasValue       = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            f1 = _mm256_add_ps(f1, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f1 = _mm256_min_ps(f1, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                f1 = _mm256_max_ps(f1, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
            } else {
                POSTTREAT(0);
                POSTTREAT(1);
            }
        }
        return;
    }    
    if (1 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z           = dst + dz * dst_step_tmp;
            const auto src_x   = src;
            auto dst_x         = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);

            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z     = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);

                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));

                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
            }
            auto D0 = D00;
            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);

            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            // srcKernelSum x weight quant-zero correction
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue);
            f0 = _mm256_mul_ps(f0, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            auto biasValue       = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
            } else {
                POSTTREAT(0);
            }
        }
        return;
    }
}