
//
//  GemmInt8.cpp
//  MNN
//
//  Created by MNN on 2020/09/22.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "GemmCommon.hpp"
#include "FunctionSummary.hpp"
#include "core/Macro.h"
#include <math.h>

#define AVX2_PACKINT8 8
#define GEMMINT8_AVX2_E 4
#define GEMMINT8_AVX2_L 4
#define GEMMINT8_AVX2_H 8

namespace {
static inline __m128i mm_loadu_si128(const void* addr) {
    return _mm_loadu_si128((__m128i const*)addr);
}
static inline void MNN__mm_storeu_si64(void* add, __m128i value) {
    float temp[4];
    _mm_storeu_ps(temp, _mm_castsi128_ps(value));
    ::memcpy(add, temp, sizeof(int64_t));
}
} // namespace

#define POSTTREAT(N) \
f##N = _mm256_min_ps(f##N, maxValue);\
f##N = _mm256_max_ps(f##N, minValue);\
auto m##N = _mm256_cmp_ps(f##N, zero128, 1);\
m##N = _mm256_blendv_ps(plus, minus, m##N);\
f##N = _mm256_add_ps(f##N, m##N);\
D##N = _mm256_cvtps_epi32(_mm256_round_ps(f##N, 3));\
D##N = _mm256_add_epi32(D##N, offset);\
D##N = _mm256_packs_epi32(D##N, _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(D##N), _mm256_castsi256_ps(D##N), 1)));\
auto d##N = _mm_packus_epi16(_mm256_castsi256_si128(D##N), _mm256_castsi256_si128(_mm256_castps_si256(zero128)));\
MNN__mm_storeu_si64(dst_x + N * 8, d##N);

inline __m256i NORMAL_HADD(__m256i x, __m256i y) {
    auto c0 = _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(x), _mm256_castsi256_ps(y), 32));
    auto c1 = _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(x), _mm256_castsi256_ps(y), 49));
    return _mm256_hadd_epi32(c0, c1);
}

#define EXTRACT_ADD(i)\
auto d##i##0 = _mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(D##i), 0));\
auto d##i##1 = _mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(D##i), 1));\
auto d##i = _mm_add_epi32(d##i##0, d##i##1);
#define COMPUTE(u, v)\
D##u##v = _mm256_add_epi32(D##u##v, _mm256_madd_epi16(W##u, S##v));

#define LOAD_INT4_TO_INT8 \
auto w_int4 = _mm_loadu_si128((__m128i const*)weight_sz);\
auto w_0 = _mm_and_si128(mask, _mm_srli_epi16(w_int4, 4));\
auto w_1 = _mm_and_si128(mask, w_int4);
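// Notes on the helpers above (descriptive comments, not part of the kernels):
// - LOAD_INT4_TO_INT8 expands 16 bytes of packed int4 weights into two
//   16-byte vectors: w_0 holds the high nibbles, w_1 the low nibbles, both as
//   unsigned values in [0, 15]; the quant zero point is folded into the
//   per-channel weightBias/scale terms later.
// - _mm256_maddubs_epi16(src, w) multiplies unsigned src bytes by signed
//   weight bytes and adds adjacent pairs with int16 saturation (the nibble
//   range keeps that safe for the w4 kernel); the following
//   _mm256_madd_epi16(..., 1) widens pairwise into non-saturating int32 sums.
// - COMPUTE(u, v) accumulates weight half W##u against broadcast source S##v
//   into the int32 accumulator D##u##v; NORMAL_HADD merges two such halves.
// - POSTTREAT(N) clamps, rounds half away from zero (add +/-0.5 then truncate:
//   _mm256_round_ps(..., 3) is _MM_FROUND_TO_ZERO), re-centers by +128 into
//   the uint8 domain, and saturate-packs eight results down to 8 bits.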

void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                          size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    MNN_ASSERT(post->useInt8 == 0);
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    auto zero128 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_ps(post->minValue);
    auto maxValue = _mm256_set1_ps(post->maxValue);
    auto offset = _mm256_set1_epi32(128);
    __m256 fp32min, fp32max;
    if (post->fp32minmax) {
        fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
    }
    const float* biasPtr = nullptr;
    int inputBlockNum = 1;
    if (post->biasFloat) {
        biasPtr = post->biasFloat;
    }
    auto accumbuff = post->accumBuffer;
    auto blockNum = post->blockNum;
    if (post->inputBias) {
        inputBlockNum = blockNum;
    }
    int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) / 2;
    int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float) * GEMMINT8_AVX2_H;
    const __m128i mask = _mm_set1_epi8(0xf);
    auto srcKernelSumPtr = post->srcKernelSum;
    __m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3;
    auto neg128_f = _mm256_set1_ps(-128.f);
    __m256 extrascale0 = _mm256_setzero_ps();
    __m256 extrascale1 = _mm256_setzero_ps();
    __m256 extrascale2 = _mm256_setzero_ps();
    __m256 extrascale3 = _mm256_setzero_ps();
    __m256 extrabias0 = _mm256_setzero_ps();
    __m256 extrabias1 = _mm256_setzero_ps();
    __m256 extrabias2 = _mm256_setzero_ps();
    __m256 extrabias3 = _mm256_setzero_ps();
    if (post->inputScale) {
        if (GEMMINT8_AVX2_E == realDst) {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            extrascale3 = _mm256_set1_ps(post->inputScale[3]);
        } else {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            if (realDst > 1) {
                extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            }
            if (realDst > 2) {
                extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            }
        }
    }
    auto oneValue = _mm256_set1_epi16(1);
    __m256 bias0, bias1, bias2, bias3;
    // weight&scale&bias: [oc/hp, blocknum, weight_step_Z]
    // weight_step_Z: [(kx*ky), ic/lp/blocknum, hp, lp] + [hp] + [hp]
    // input: [blocknum, blockLu, EP, LP]
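    // Per-block flow: each block's int32 dot products are dequantized with the
    // block's scale, corrected by the kernelSum x quant-zero terms, and summed
    // across blocks through post->accumBuffer; only the last block adds the
    // float bias, applies the optional fp32 clamp and writes dst.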
    if (GEMMINT8_AVX2_E == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
                kernelSum3 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[3]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                __m256i D03 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
                    auto s3 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 3));
                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                    D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                    D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
                    D03 = _mm256_add_epi32(D03, _mm256_madd_epi16(_mm256_maddubs_epi16(s3, w0), oneValue));
                }
                auto D0 = D00;
                auto D1 = D01;
                auto D2 = D02;
                auto D3 = D03;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                auto f3 = _mm256_cvtepi32_ps(D3);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // ..third
                auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // ..fourth
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                f3 = _mm256_mul_ps(f3, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                        extrascale3 = _mm256_set1_ps((post->inputScale + bk * realDst)[3]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    f3 = _mm256_mul_ps(f3, extrascale3);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            extrabias3 = _mm256_set1_ps((post->inputBias + bk * realDst)[3]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                            bias3 = _mm256_mul_ps(extrabias3, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                            bias3 = _mm256_mul_ps(_mm256_mul_ps(extrascale3, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                        f3 = _mm256_add_ps(f3, bias3);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                f3 = _mm256_add_ps(f3, xy0_3);
                if (bk > 0) {
                    auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                    auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                    auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                    auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
                    f0 = _mm256_add_ps(f0, dstv0);
                    f1 = _mm256_add_ps(f1, dstv1);
                    f2 = _mm256_add_ps(f2, dstv2);
                    f3 = _mm256_add_ps(f3, dstv3);
                }
                if (bk == blockNum - 1) {
                    if (biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                        f3 = _mm256_add_ps(f3, biasValue);
                    }
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f1 = _mm256_min_ps(f1, fp32max);
                        f2 = _mm256_min_ps(f2, fp32max);
                        f3 = _mm256_min_ps(f3, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                        f1 = _mm256_max_ps(f1, fp32min);
                        f2 = _mm256_max_ps(f2, fp32min);
                        f3 = _mm256_max_ps(f3, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                    _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
                } else {
                    _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                    _mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
                }
            }
        }
        return;
    }
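    // The realDst == 3 / 2 / 1 tails below repeat the same per-block scheme
    // with correspondingly fewer destination columns.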
    if (3 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                    D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                    D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
                }
                auto D0 = D00;
                auto D1 = D01;
                auto D2 = D02;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // ..third
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                if (bk > 0) {
                    auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                    auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                    auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                    f0 = _mm256_add_ps(f0, dstv0);
                    f1 = _mm256_add_ps(f1, dstv1);
                    f2 = _mm256_add_ps(f2, dstv2);
                }
                if (bk == blockNum - 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                    }
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f1 = _mm256_min_ps(f1, fp32max);
                        f2 = _mm256_min_ps(f2, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                        f1 = _mm256_max_ps(f1, fp32min);
                        f2 = _mm256_max_ps(f2, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                } else {
                    _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                    _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                }
            }
        }
        return;
    }
    if (2 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                    D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                }
                auto D0 = D00;
                auto D1 = D01;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                if (bk > 0) {
                    auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                    auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                    f0 = _mm256_add_ps(f0, dstv0);
                    f1 = _mm256_add_ps(f1, dstv1);
                }
                if (bk == blockNum - 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                    }
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f1 = _mm256_min_ps(f1, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                        f1 = _mm256_max_ps(f1, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                } else {
                    _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                    _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                }
            }
        }
        return;
    }
    if (1 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                // source kernel sum
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                __m256i D00 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    LOAD_INT4_TO_INT8;
                    auto w0 = _mm256_insertf128_si256(_mm256_castsi128_si256(w_0), w_1, 1);
                    auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                    D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                }
                auto D0 = D00;
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                f0 = _mm256_mul_ps(f0, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                if (bk > 0) {
                    auto dstv = _mm256_loadu_ps(((float*)accum_x));
                    f0 = _mm256_add_ps(f0, dstv);
                }
                if (bk == 0) {
                    if (biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                    }
                }
                if (bk == blockNum - 1) {
                    if (post->fp32minmax) {
                        f0 = _mm256_min_ps(f0, fp32max);
                        f0 = _mm256_max_ps(f0, fp32min);
                    }
                    _mm256_storeu_ps(((float*)dst_x), f0);
                } else {
                    _mm256_storeu_ps(((float*)accum_x), f0);
                }
            }
        }
        return;
    }
}
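// Int8-weight variant. Weights are loaded as two 16-byte halves, sign-extended
// to int16 (_mm256_cvtepi8_epi16) and multiplied against zero-extended source
// bytes with _mm256_madd_epi16 (see COMPUTE), which cannot saturate; the two
// half-accumulators are then merged by NORMAL_HADD. Unlike the w4 kernel, this
// path also supports post->useInt8 == 1, requantizing results via POSTTREAT.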
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                            size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    auto zero128 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_ps(post->minValue);
    auto maxValue = _mm256_set1_ps(post->maxValue);
    auto plus = _mm256_set1_ps(0.5f);
    auto minus = _mm256_set1_ps(-0.5f);
    auto offset = _mm256_set1_epi32(128);
    __m256 fp32min, fp32max;
    if (0 == post->useInt8 && post->fp32minmax) {
        fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
    }
    const float* biasPtr = nullptr;
    if (post->biasFloat) {
        biasPtr = post->biasFloat;
    }
    int inputBlockNum = 1;
    auto accumbuff = post->accumBuffer;
    auto blockNum = post->blockNum;
    if (post->inputBias) {
        inputBlockNum = blockNum;
    }
    int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
    int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float) * GEMMINT8_AVX2_H;
    auto srcKernelSumPtr = post->srcKernelSum;
    __m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3;
    auto neg128_f = _mm256_set1_ps(-128.f);
    __m256 extrascale0 = _mm256_setzero_ps();
    __m256 extrascale1 = _mm256_setzero_ps();
    __m256 extrascale2 = _mm256_setzero_ps();
    __m256 extrascale3 = _mm256_setzero_ps();
    __m256 extrabias0 = _mm256_setzero_ps();
    __m256 extrabias1 = _mm256_setzero_ps();
    __m256 extrabias2 = _mm256_setzero_ps();
    __m256 extrabias3 = _mm256_setzero_ps();
    if (post->inputScale) {
        if (GEMMINT8_AVX2_E == realDst) {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            extrascale3 = _mm256_set1_ps(post->inputScale[3]);
        } else {
            extrascale0 = _mm256_set1_ps(post->inputScale[0]);
            if (realDst > 1) {
                extrascale1 = _mm256_set1_ps(post->inputScale[1]);
            }
            if (realDst > 2) {
                extrascale2 = _mm256_set1_ps(post->inputScale[2]);
            }
        }
    }
    __m256 bias0, bias1, bias2, bias3;
    if (GEMMINT8_AVX2_E == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
                kernelSum3 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[3]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                __m256i D03 = _mm256_set1_epi32(0);
                __m256i D10 = _mm256_set1_epi32(0);
                __m256i D11 = _mm256_set1_epi32(0);
                __m256i D12 = _mm256_set1_epi32(0);
                __m256i D13 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);
                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
                    auto s3 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 3));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    auto S1 = _mm256_cvtepu8_epi16(s1);
                    auto S2 = _mm256_cvtepu8_epi16(s2);
                    auto S3 = _mm256_cvtepu8_epi16(s3);
                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                    COMPUTE(0, 1);
                    COMPUTE(1, 1);
                    COMPUTE(0, 2);
                    COMPUTE(1, 2);
                    COMPUTE(0, 3);
                    COMPUTE(1, 3);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto D1 = NORMAL_HADD(D01, D11);
                auto D2 = NORMAL_HADD(D02, D12);
                auto D3 = NORMAL_HADD(D03, D13);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                auto f3 = _mm256_cvtepi32_ps(D3);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // ..third
                auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // ..fourth
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                f3 = _mm256_mul_ps(f3, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                        extrascale3 = _mm256_set1_ps((post->inputScale + bk * realDst)[3]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    f3 = _mm256_mul_ps(f3, extrascale3);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            extrabias3 = _mm256_set1_ps((post->inputBias + bk * realDst)[3]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                            bias3 = _mm256_mul_ps(extrabias3, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                            bias3 = _mm256_mul_ps(_mm256_mul_ps(extrascale3, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                        f3 = _mm256_add_ps(f3, bias3);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                f3 = _mm256_add_ps(f3, xy0_3);
                if (post->useInt8 == 1) {
                    if (biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                        f3 = _mm256_add_ps(f3, biasValue);
                    }
                    POSTTREAT(0);
                    POSTTREAT(1);
                    POSTTREAT(2);
                    POSTTREAT(3);
                } else {
                    if (bk > 0) {
                        auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                        auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                        auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                        auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
                        f0 = _mm256_add_ps(f0, dstv0);
                        f1 = _mm256_add_ps(f1, dstv1);
                        f2 = _mm256_add_ps(f2, dstv2);
                        f3 = _mm256_add_ps(f3, dstv3);
                    }
                    if (bk == blockNum - 1) {
                        if (biasPtr) {
                            const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                            f1 = _mm256_add_ps(f1, biasValue);
                            f2 = _mm256_add_ps(f2, biasValue);
                            f3 = _mm256_add_ps(f3, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f1 = _mm256_min_ps(f1, fp32max);
                            f2 = _mm256_min_ps(f2, fp32max);
                            f3 = _mm256_min_ps(f3, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                            f1 = _mm256_max_ps(f1, fp32min);
                            f2 = _mm256_max_ps(f2, fp32min);
                            f3 = _mm256_max_ps(f3, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                        _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                        _mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
                    }
                }
            }
        }
        return;
    }
    if (3 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D02 = _mm256_set1_epi32(0);
                __m256i D10 = _mm256_set1_epi32(0);
                __m256i D11 = _mm256_set1_epi32(0);
                __m256i D12 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);
                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
                    auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    auto S1 = _mm256_cvtepu8_epi16(s1);
                    auto S2 = _mm256_cvtepu8_epi16(s2);
                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                    COMPUTE(0, 1);
                    COMPUTE(1, 1);
                    COMPUTE(0, 2);
                    COMPUTE(1, 2);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto D1 = NORMAL_HADD(D01, D11);
                auto D2 = NORMAL_HADD(D02, D12);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                auto f2 = _mm256_cvtepi32_ps(D2);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
                auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // ..third
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                f2 = _mm256_mul_ps(f2, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                        extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    f2 = _mm256_mul_ps(f2, extrascale2);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                            bias2 = _mm256_mul_ps(extrabias2, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                            bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                        f2 = _mm256_add_ps(f2, bias2);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                f2 = _mm256_add_ps(f2, xy0_2);
                if (post->useInt8 == 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                        f2 = _mm256_add_ps(f2, biasValue);
                    }
                    POSTTREAT(0);
                    POSTTREAT(1);
                    POSTTREAT(2);
                } else {
                    if (bk > 0) {
                        auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                        auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                        auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
                        f0 = _mm256_add_ps(f0, dstv0);
                        f1 = _mm256_add_ps(f1, dstv1);
                        f2 = _mm256_add_ps(f2, dstv2);
                    }
                    if (bk == blockNum - 1) {
                        if (nullptr != biasPtr) {
                            const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                            f1 = _mm256_add_ps(f1, biasValue);
                            f2 = _mm256_add_ps(f2, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f1 = _mm256_min_ps(f1, fp32max);
                            f2 = _mm256_min_ps(f2, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                            f1 = _mm256_max_ps(f1, fp32min);
                            f2 = _mm256_max_ps(f2, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                        _mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
                    }
                }
            }
        }
        return;
    }
    if (2 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D01 = _mm256_set1_epi32(0);
                __m256i D10 = _mm256_set1_epi32(0);
                __m256i D11 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * weight_step_Y;
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);
                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    auto S1 = _mm256_cvtepu8_epi16(s1);
                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                    COMPUTE(0, 1);
                    COMPUTE(1, 1);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto D1 = NORMAL_HADD(D01, D11);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                auto f1 = _mm256_cvtepi32_ps(D1);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
                f0 = _mm256_mul_ps(f0, scaleValue);
                f1 = _mm256_mul_ps(f1, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                        extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    f1 = _mm256_mul_ps(f1, extrascale1);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                            bias1 = _mm256_mul_ps(extrabias1, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                            bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                        f1 = _mm256_add_ps(f1, bias1);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                f1 = _mm256_add_ps(f1, xy0_1);
                if (post->useInt8 == 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                        f1 = _mm256_add_ps(f1, biasValue);
                    }
                    POSTTREAT(0);
                    POSTTREAT(1);
                } else {
                    if (bk > 0) {
                        auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
                        auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
                        f0 = _mm256_add_ps(f0, dstv0);
                        f1 = _mm256_add_ps(f1, dstv1);
                    }
                    if (bk == blockNum - 1) {
                        if (nullptr != biasPtr) {
                            const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                            f1 = _mm256_add_ps(f1, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f1 = _mm256_min_ps(f1, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                            f1 = _mm256_max_ps(f1, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
                        _mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
                    }
                }
            }
        }
        return;
    }
    if (1 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            auto dst_x = dst + dz * dst_step_tmp;
            auto accum_x = accumbuff;
            for (int bk = 0; bk < blockNum; ++bk) {
                // block's weight&scale&bias
                const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
                const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
                const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
                // block's input
                const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
                // source kernel sum
                kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
                __m256i D00 = _mm256_set1_epi32(0);
                __m256i D10 = _mm256_set1_epi32(0);
                for (int sz = 0; sz < src_depth_quad; ++sz) {
                    const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                    const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                    auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
                    auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
                    auto W0 = _mm256_cvtepi8_epi16(w0);
                    auto W1 = _mm256_cvtepi8_epi16(w1);
                    auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
                    auto S0 = _mm256_cvtepu8_epi16(s0);
                    COMPUTE(0, 0);
                    COMPUTE(1, 0);
                }
                auto D0 = NORMAL_HADD(D00, D10);
                auto scaleValue = _mm256_loadu_ps(scale_dz);
                auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
                auto f0 = _mm256_cvtepi32_ps(D0);
                // x_kernelSum x w_quantZero
                auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
                f0 = _mm256_mul_ps(f0, scaleValue);
                if (post->inputScale) {
                    if (post->inputBias) {
                        extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
                    }
                    f0 = _mm256_mul_ps(f0, extrascale0);
                    if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                        if (post->inputBias) {
                            auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
                            bias0 = _mm256_mul_ps(extrabias0, wsum);
                        } else if (bk == blockNum - 1) { // if input not block quant, only accum once!
                            auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
                            auto wsum = _mm256_loadu_ps(wsumDz);
                            bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
                        }
                        f0 = _mm256_add_ps(f0, bias0);
                    }
                }
                f0 = _mm256_add_ps(f0, xy0_0);
                if (post->useInt8 == 1) {
                    if (nullptr != biasPtr) {
                        const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                        auto biasValue = _mm256_loadu_ps(bias_dz);
                        f0 = _mm256_add_ps(f0, biasValue);
                    }
                    POSTTREAT(0);
                } else {
                    if (bk > 0) {
                        auto dstv = _mm256_loadu_ps(((float*)accum_x));
                        f0 = _mm256_add_ps(f0, dstv);
                    }
                    if (bk == blockNum - 1) {
                        if (biasPtr) {
                            const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
                            auto biasValue = _mm256_loadu_ps(bias_dz);
                            f0 = _mm256_add_ps(f0, biasValue);
                        }
                        if (post->fp32minmax) {
                            f0 = _mm256_min_ps(f0, fp32max);
                            f0 = _mm256_max_ps(f0, fp32min);
                        }
                        _mm256_storeu_ps(((float*)dst_x), f0);
                    } else {
                        _mm256_storeu_ps(((float*)accum_x), f0);
                    }
                }
            }
        }
        return;
    }
}
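// "Fast" variant: a single-pass path (no block loop, no accumulation buffer)
// that feeds int8 weights straight into _mm256_maddubs_epi16. That instruction
// saturates its int16 pair-sums, so this kernel is presumably only selected
// when the caller knows the value ranges make saturation impossible.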
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                                 size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    auto zero128 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_ps(post->minValue);
    auto maxValue = _mm256_set1_ps(post->maxValue);
    auto plus = _mm256_set1_ps(0.5f);
    auto minus = _mm256_set1_ps(-0.5f);
    auto oneValue = _mm256_set1_epi16(1);
    auto offset = _mm256_set1_epi32(128);
    __m256 fp32min, fp32max;
    if (0 == post->useInt8) {
        fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
    }
    auto srcKernelSumPtr = post->srcKernelSum;
    __m256 kernelSum0 = _mm256_setzero_ps();
    __m256 kernelSum1 = _mm256_setzero_ps();
    __m256 kernelSum2 = _mm256_setzero_ps();
    __m256 kernelSum3 = _mm256_setzero_ps();
    if (GEMMINT8_AVX2_E == realDst) {
        kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]);
        kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]);
        kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
        kernelSum3 = _mm256_set1_ps(post->srcKernelSum[3]);
    } else {
        kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]);
        if (realDst > 1) {
            kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]);
        }
        if (realDst > 2) {
            kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
        }
    }
    int weight_step_Z = src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) + 4 * 2 * GEMMINT8_AVX2_H;
    int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
    if (GEMMINT8_AVX2_E == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            __m256i D01 = _mm256_set1_epi32(0);
            __m256i D02 = _mm256_set1_epi32(0);
            __m256i D03 = _mm256_set1_epi32(0);
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
                auto s3 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 3));
                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
                D03 = _mm256_add_epi32(D03, _mm256_madd_epi16(_mm256_maddubs_epi16(s3, w0), oneValue));
            }
            auto D0 = D00;
            auto D1 = D01;
            auto D2 = D02;
            auto D3 = D03;
            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            auto f1 = _mm256_cvtepi32_ps(D1);
            auto f2 = _mm256_cvtepi32_ps(D2);
            auto f3 = _mm256_cvtepi32_ps(D3);
            // x_kernelSum x w_quantZero
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
            auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
            auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // ..third
            auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // ..fourth
            f0 = _mm256_mul_ps(f0, scaleValue);
            f1 = _mm256_mul_ps(f1, scaleValue);
            f2 = _mm256_mul_ps(f2, scaleValue);
            f3 = _mm256_mul_ps(f3, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            f1 = _mm256_add_ps(f1, xy0_1);
            f2 = _mm256_add_ps(f2, xy0_2);
            f3 = _mm256_add_ps(f3, xy0_3);
            auto biasValue = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            f1 = _mm256_add_ps(f1, biasValue);
            f2 = _mm256_add_ps(f2, biasValue);
            f3 = _mm256_add_ps(f3, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f1 = _mm256_min_ps(f1, fp32max);
                f2 = _mm256_min_ps(f2, fp32max);
                f3 = _mm256_min_ps(f3, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                f1 = _mm256_max_ps(f1, fp32min);
                f2 = _mm256_max_ps(f2, fp32min);
                f3 = _mm256_max_ps(f3, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
                _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
            } else {
                POSTTREAT(0);
                POSTTREAT(1);
                POSTTREAT(2);
                POSTTREAT(3);
            }
        }
        return;
    }
    if (3 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            __m256i D01 = _mm256_set1_epi32(0);
            __m256i D02 = _mm256_set1_epi32(0);
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
                D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
            }
            auto D0 = D00;
            auto D1 = D01;
            auto D2 = D02;
            // auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
            // D0 = _mm256_add_epi32(D0, biasValue0);
            // D1 = _mm256_add_epi32(D1, biasValue0);
            // D2 = _mm256_add_epi32(D2, biasValue0);
            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            auto f1 = _mm256_cvtepi32_ps(D1);
            auto f2 = _mm256_cvtepi32_ps(D2);
            // x_kernelSum x w_quantZero
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
            auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
            auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // ..third
            f0 = _mm256_mul_ps(f0, scaleValue);
            f1 = _mm256_mul_ps(f1, scaleValue);
            f2 = _mm256_mul_ps(f2, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            f1 = _mm256_add_ps(f1, xy0_1);
            f2 = _mm256_add_ps(f2, xy0_2);
            auto biasValue = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            f1 = _mm256_add_ps(f1, biasValue);
            f2 = _mm256_add_ps(f2, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f1 = _mm256_min_ps(f1, fp32max);
                f2 = _mm256_min_ps(f2, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                f1 = _mm256_max_ps(f1, fp32min);
                f2 = _mm256_max_ps(f2, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
                _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
            } else {
                POSTTREAT(0);
                POSTTREAT(1);
                POSTTREAT(2);
            }
        }
        return;
    }
    if (2 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            __m256i D01 = _mm256_set1_epi32(0);
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
                D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
            }
            auto D0 = D00;
            auto D1 = D01;
            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            auto f1 = _mm256_cvtepi32_ps(D1);
            // x_kernelSum x w_quantZero
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
            auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second
            f0 = _mm256_mul_ps(f0, scaleValue);
            f1 = _mm256_mul_ps(f1, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            f1 = _mm256_add_ps(f1, xy0_1);
            auto biasValue = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            f1 = _mm256_add_ps(f1, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f1 = _mm256_min_ps(f1, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                f1 = _mm256_max_ps(f1, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
                _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
            } else {
                POSTTREAT(0);
                POSTTREAT(1);
            }
        }
        return;
    }
    if (1 == realDst) {
        for (int dz = 0; dz < dst_depth_quad; ++dz) {
            const auto weight_dz = weight + dz * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
            auto dst_z = dst + dz * dst_step_tmp;
            const auto src_x = src;
            auto dst_x = dst_z;
            __m256i D00 = _mm256_set1_epi32(0);
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
                const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
                auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
                auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
                D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
            }
            auto D0 = D00;
            auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
            auto scaleValue = _mm256_loadu_ps(scale_dz);
            auto f0 = _mm256_cvtepi32_ps(D0);
            // x_kernelSum x w_quantZero
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
            f0 = _mm256_mul_ps(f0, scaleValue);
            f0 = _mm256_add_ps(f0, xy0_0);
            auto biasValue = _mm256_loadu_ps(weightBias_dz);
            f0 = _mm256_add_ps(f0, biasValue);
            if (post->useInt8 == 0) {
                f0 = _mm256_min_ps(f0, fp32max);
                f0 = _mm256_max_ps(f0, fp32min);
                _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
            } else {
                POSTTREAT(0);
            }
        }
        return;
    }
}
#undef MAIN_COMPUTE
#undef STORE_TEMP
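// Depthwise int8 line kernel: src and weights arrive pre-widened to int16
// (note the casts below), 16 channels per pixel split across two 8-lane int32
// accumulators d0/d1. The 0xD8 permutes compensate for AVX2's per-128-bit-lane
// unpacklo/unpackhi behavior so channel order is preserved.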
void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters,
                                               size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step,
                                               int8_t* idxOrder) {
    int pack = 16;
    auto dst = dstO;
    auto src = (const int16_t*)srcO;
    auto weight = (const int16_t*)weightO;
    auto biasValue0 = _mm256_castps_si256(_mm256_loadu_ps((const float*)parameters->bias));
    auto biasValue1 = _mm256_castps_si256(_mm256_loadu_ps((const float*)parameters->bias + 8));
    auto scaleValue0 = _mm256_loadu_ps((const float*)parameters->scale);
    auto scaleValue1 = _mm256_loadu_ps((const float*)parameters->scale + 8);
    __m256i zero = _mm256_setzero_si256();
    __m256i d0, d1;
    int dx, fx, fy;
    __m256 zero256 = _mm256_set1_ps(0.0f);
    auto minValue = _mm256_set1_epi16((int16_t)(parameters->minValue + 128));
    auto maxValue = _mm256_set1_epi16((int16_t)(parameters->maxValue + 128));
    __m256 plus = _mm256_set1_ps(0.5f);
    __m256 minus = _mm256_set1_ps(-0.5f);
    auto offset = _mm256_set1_epi32(128);
    for (dx = 0; dx < width; ++dx) {
        d0 = biasValue0;
        d1 = biasValue1;
        auto dst_x = dst;
        const auto src_z = src;
        for (fy = 0; fy < fh; ++fy) {
            const auto src_y = src_z + fy * dilateY_step;
            const auto weight_y = weight + fy * fw * pack;
            for (fx = 0; fx < fw; ++fx) {
                const auto src_x = src_y + fx * dilateX_step;
                auto s0_16 = _mm256_castps_si256(_mm256_loadu_ps((float*)src_x));
                s0_16 = _mm256_permute4x64_epi64(s0_16, 0xD8); // Reorder 0,1,2,3->0,2,1,3 to ensure s0_32 is 0,1 and s1_32 is 2,3.
                auto s0_32 = _mm256_unpacklo_epi16(s0_16, zero);
                auto s1_32 = _mm256_unpackhi_epi16(s0_16, zero);
                const auto weight_x = weight_y + pack * fx;
                auto w0_16 = _mm256_castps_si256(_mm256_loadu_ps((float*)weight_x));
                w0_16 = _mm256_permute4x64_epi64(w0_16, 0xD8);
                auto w0_32 = _mm256_unpacklo_epi16(w0_16, zero);
                auto w1_32 = _mm256_unpackhi_epi16(w0_16, zero);
                d0 = _mm256_add_epi32(d0, _mm256_madd_epi16(w0_32, s0_32));
                d1 = _mm256_add_epi32(d1, _mm256_madd_epi16(w1_32, s1_32));
            }
        }
        __m256 f0 = _mm256_cvtepi32_ps(d0);
        __m256 f1 = _mm256_cvtepi32_ps(d1);
        f0 = _mm256_mul_ps(f0, scaleValue0);
        f1 = _mm256_mul_ps(f1, scaleValue1);
        auto m0 = _mm256_cmp_ps(f0, zero256, 1);
        auto m1 = _mm256_cmp_ps(f1, zero256, 1);
        m0 = _mm256_blendv_ps(plus, minus, m0);
        m1 = _mm256_blendv_ps(plus, minus, m1);
        f0 = _mm256_add_ps(f0, m0);
        f1 = _mm256_add_ps(f1, m1);
        // _MM_FROUND_TO_ZERO
        d0 = _mm256_cvtps_epi32(_mm256_round_ps(f0, 3));
        d1 = _mm256_cvtps_epi32(_mm256_round_ps(f1, 3));
        d0 = _mm256_add_epi32(d0, offset);
        d1 = _mm256_add_epi32(d1, offset);
        d0 = _mm256_permute4x64_epi64(_mm256_packs_epi32(d0, d1), 0xD8);
        d0 = _mm256_min_epi16(d0, maxValue);
        d0 = _mm256_max_epi16(d0, minValue);
        auto y256i = _mm256_permute4x64_epi64(_mm256_packus_epi16(d0, _mm256_setzero_si256()), 0xD8);
        auto y128 = _mm_castsi128_ps(_mm256_extracti128_si256(y256i, 0));
        _mm_storeu_ps((float*)dst, y128);
        dst += 16;
        src += src_w_step;
    }
}
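// Quantize/dequantize helpers. Per element, _AVX_MNNFloat2Int8 computes
//   y = clamp(x * scale + zeroPoint, minV, maxV)
//   d = (int)trunc(y + (y < 0 ? -0.5f : 0.5f)) + 128   // round half away from zero
// and saturate-packs d to 8 bits. _AVX_MNNInt8ScaleToFloat inverts the mapping:
//   out = ((uint8)s - (zeroPoint + 128)) * scale
// processing 4 packs of 8 values per loop, with a memcpy-staged tail.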
void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV,
                        const float* zeroPoint, ssize_t quanParamVec) {
    auto zero = _mm256_set1_epi32(0);
    auto minValue = _mm256_set1_ps(minV);
    auto maxValue = _mm256_set1_ps(maxV);
    auto zeroPointValue = _mm256_set1_ps(zeroPoint[0]);
    auto offset = _mm256_set1_epi32(128);
    auto plus = _mm256_set1_ps(0.5f);
    auto minus = _mm256_set1_ps(-0.5f);
    auto scaleValue = _mm256_set1_ps(scalep[0]);
    if (quanParamVec & 1) {
        scaleValue = _mm256_loadu_ps(scalep);
    }
    if (quanParamVec >> 1) {
        zeroPointValue = _mm256_loadu_ps(zeroPoint);
    }
    for (int i = 0; i < sizeQuad; ++i) {
        auto f0 = _mm256_loadu_ps(src + 8 * i);
        f0 = _mm256_mul_ps(f0, scaleValue);
        f0 = _mm256_add_ps(f0, zeroPointValue);
        f0 = _mm256_min_ps(f0, maxValue);
        f0 = _mm256_max_ps(f0, minValue);
        auto m0 = _mm256_cmp_ps(f0, _mm256_castsi256_ps(zero), 1);
        m0 = _mm256_blendv_ps(plus, minus, m0);
        f0 = _mm256_add_ps(f0, m0);
        // 3: _MM_FROUND_TO_ZERO
        auto d0 = _mm256_cvtps_epi32(_mm256_round_ps(f0, 3));
        d0 = _mm256_add_epi32(d0, offset);
        d0 = _mm256_packs_epi32(d0, _mm256_setzero_si256());
        d0 = _mm256_permute4x64_epi64(d0, 0xD8);
        auto x = _mm256_packus_epi16(d0, _mm256_setzero_si256());
        *((int64_t*)dst + i) = _mm256_extract_epi64(x, 0);
    }
}

void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, const float* zeroPoint, ssize_t quanParamVec) {
    auto sizeC4 = sizeQuad / 4;
    auto sizeRemain = sizeQuad % 4;
    auto zero = _mm256_set1_epi32(0);
    auto scaleValue = _mm256_set1_ps(scale[0]);
    auto zeroPointValue = _mm256_set1_ps(zeroPoint[0] + 128.f);
    if (quanParamVec & 1) {
        scaleValue = _mm256_loadu_ps(scale);
    }
    if (quanParamVec >> 1) {
        zeroPointValue = _mm256_add_ps(_mm256_loadu_ps(zeroPoint), _mm256_set1_ps(128.f));
    }
    for (int i = 0; i < sizeC4; ++i) {
        auto s = _mm256_castps_si256(_mm256_loadu_ps((const float*)(src)));
        auto s0_16 = _mm256_permute4x64_epi64(_mm256_unpacklo_epi8(s, zero), 0xD8);
        auto s1_16 = _mm256_permute4x64_epi64(_mm256_unpackhi_epi8(s, zero), 0xD8);
        auto s0_32 = _mm256_unpacklo_epi16(s0_16, zero);
        auto s1_32 = _mm256_unpacklo_epi16(s1_16, zero);
        auto s2_32 = _mm256_unpackhi_epi16(s0_16, zero);
        auto s3_32 = _mm256_unpackhi_epi16(s1_16, zero);
        auto s0_f = _mm256_cvtepi32_ps(s0_32);
        auto s1_f = _mm256_cvtepi32_ps(s1_32);
        auto s2_f = _mm256_cvtepi32_ps(s2_32);
        auto s3_f = _mm256_cvtepi32_ps(s3_32);
        s0_f = _mm256_sub_ps(s0_f, zeroPointValue);
        s1_f = _mm256_sub_ps(s1_f, zeroPointValue);
        s2_f = _mm256_sub_ps(s2_f, zeroPointValue);
        s3_f = _mm256_sub_ps(s3_f, zeroPointValue);
        _mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
        _mm256_storeu_ps(dst + 8 * 1, _mm256_mul_ps(s1_f, scaleValue));
        _mm256_storeu_ps(dst + 8 * 2, _mm256_mul_ps(s2_f, scaleValue));
        _mm256_storeu_ps(dst + 8 * 3, _mm256_mul_ps(s3_f, scaleValue));
        src += 32;
        dst += 32;
    }
    if (sizeRemain > 0) {
        int8_t srcTemp[256];
        ::memcpy(srcTemp, src, sizeRemain * 8);
        auto s = _mm256_castps_si256(_mm256_loadu_ps((const float*)(srcTemp)));
        auto s0_16 = _mm256_permute4x64_epi64(_mm256_unpacklo_epi8(s, zero), 0xD8);
        auto s1_16 = _mm256_permute4x64_epi64(_mm256_unpackhi_epi8(s, zero), 0xD8);
        auto s0_32 = _mm256_unpacklo_epi16(s0_16, zero);
        auto s1_32 = _mm256_unpacklo_epi16(s1_16, zero);
        auto s2_32 = _mm256_unpackhi_epi16(s0_16, zero);
        auto s3_32 = _mm256_unpackhi_epi16(s1_16, zero);
        auto s0_f = _mm256_cvtepi32_ps(s0_32);
        auto s1_f = _mm256_cvtepi32_ps(s1_32);
        auto s2_f = _mm256_cvtepi32_ps(s2_32);
        auto s3_f = _mm256_cvtepi32_ps(s3_32);
        s0_f = _mm256_sub_ps(s0_f, zeroPointValue);
        s1_f = _mm256_sub_ps(s1_f, zeroPointValue);
        s2_f = _mm256_sub_ps(s2_f, zeroPointValue);
        s3_f = _mm256_sub_ps(s3_f, zeroPointValue);
        switch (sizeRemain) {
            case 3:
                _mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
                _mm256_storeu_ps(dst + 8 * 1, _mm256_mul_ps(s1_f, scaleValue));
                _mm256_storeu_ps(dst + 8 * 2, _mm256_mul_ps(s2_f, scaleValue));
                break;
            case 2:
                _mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
                _mm256_storeu_ps(dst + 8 * 1, _mm256_mul_ps(s1_f, scaleValue));
                break;
            case 1:
                _mm256_storeu_ps(dst + 8 * 0, _mm256_mul_ps(s0_f, scaleValue));
                break;
            default:
                break;
        }
    }
}

static void _AVX2_MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = GEMMINT8_AVX2_H;
    *SRC_UNIT = GEMMINT8_AVX2_L;
    *DST_XUNIT = GEMMINT8_AVX2_E;
}
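// Repacks im2col'd int8 source data into the [e/EP, l/LP, EP, LP] tile layout
// the GEMM kernels above consume, moving int32 units (4 int8 values along L)
// at a time. The "lastBag" bookkeeping compacts the final partial tile to
// realDstCount % EP columns instead of a full EP-wide stride.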
static void _AVXMNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) {
    int number = info[0];
    int eReal = info[1];
    int xStride = info[3];
    int xS4 = xStride * AVX2_PACKINT8 / sizeof(int32_t);
    int PUNIT = AVX2_PACKINT8 / GEMMINT8_AVX2_L;
    int FLOATPACK = AVX2_PACKINT8 / sizeof(int32_t);
    int eOutsideStride = info[2] / sizeof(int32_t);
    const int EP = GEMMINT8_AVX2_E;
    int eDest = EP;
    const int LP = GEMMINT8_AVX2_L;
    int realDstCount = info[4];
    for (int n = 0; n < number; ++n) {
        int e = el[4 * n + 0];
        int l = el[4 * n + 1];
        int eOffset = el[4 * n + 2];
        int lOffset = el[4 * n + 3];
        int eC = eOffset / EP;
        int eR = eOffset % EP;
        int eS = eDest - eR;
        bool lastBag = false;
        int eOutsideStride4LastBag = eOutsideStride;
        if (realDstCount % EP > 0) {
            int jobsE = realDstCount - eOffset - e;
            if (jobsE == 0 || (jobsE < (realDstCount % EP))) {
                lastBag = true;
            }
        }
        auto source = (int32_t*)sourceGroup[n];
        auto dest = (int32_t*)(destOrigin + eC * info[2] + eR * LP + lOffset * EP);
        //printf("e=%d, l=%d, eOffset=%d, lOffset=%d, eDest=%d\n", e, l, eOffset, lOffset, eDest);
        l = l / 4; // Use float instead of int8 * 4
        if (lastBag && e + eR < EP) {
            int elast = ALIMAX(eR + e, realDstCount % EP);
            dest = (int32_t*)(destOrigin + lOffset * elast + eC * info[2] + eR * LP);
        }
        int offsetLC = lOffset / 4;
        for (int x = 0; x < l; ++x) {
            int eRemain = e;
            auto xR = x % PUNIT;
            auto xC = x / PUNIT;
            auto d = dest;
            auto s = source + xC * eReal * FLOATPACK + xR;
            if (eR > 0) {
                int eStep = ALIMIN(eRemain, eS);
                for (int yi = 0; yi < eStep; ++yi) {
                    d[yi] = s[yi * xS4];
                }
                eRemain -= eStep;
                if (!lastBag || eRemain >= EP) {
                    d += (eOutsideStride - eR);
                } else {
                    int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                    eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float));
                    d += (eOutsideStride4LastBag - eR + offsetLC * eFill);
                }
                s += eS * xS4;
            }
            while (eRemain > 0) {
                int eStep = ALIMIN(eDest, eRemain);
                for (int yi = 0; yi < eStep; ++yi) {
                    d[yi] = s[yi * xS4];
                }
                eRemain -= eStep;
                if (!lastBag || eRemain >= EP) {
                    d += eOutsideStride;
                } else {
                    int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                    eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float));
                    d += (eOutsideStride4LastBag + offsetLC * eFill);
                }
                s += eStep * xS4;
            }
            if (lastBag && e + eR < EP) {
                int efill = ALIMAX(e + eR, realDstCount % EP);
                dest += efill;
            } else {
                dest += eDest;
            }
            offsetLC++;
        }
    }
}

void _AVX_MNNInt8FunctionInit(void* functions) {
    auto gAVX2CoreInt8Functions = (MNN::CoreInt8Functions*)functions;
    // MatMul
    gAVX2CoreInt8Functions->Int8GemmKernel = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit;
    gAVX2CoreInt8Functions->Int8GemmKernelFast = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast;
    gAVX2CoreInt8Functions->MNNGetGemmUnit = _AVX2_MNNGetGemmUnit;
    gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _AVXMNNPackC4ForMatMul_A;
#ifdef MNN_LOW_MEMORY
    gAVX2CoreInt8Functions->Int8GemmKernel_W4 = _AVX_MNNGemmInt8AddBiasScale_16x4_w4;
#endif
    // Int8 <-> Float
    gAVX2CoreInt8Functions->MNNFloat2Int8 = _AVX_MNNFloat2Int8;
    gAVX2CoreInt8Functions->MNNInt8ScaleToFloat = _AVX_MNNInt8ScaleToFloat;
    // conv depthwise
    gAVX2CoreInt8Functions->ConvDepthwiseLineInt8 = _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit;
}