
//
//  GemmFunction.hpp
//  MNN
//
//  Created by MNN on 2020/09/22.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#define MNN_UNIT_E 24

#define TRANPOSE_SAVE(u, v, z0, z3, z6, z9) \
    { \
        auto m0 = _mm256_extractf128_ps(z0, u); \
        auto m1 = _mm256_extractf128_ps(z3, u); \
        auto m2 = _mm256_extractf128_ps(z6, u); \
        auto m3 = _mm256_extractf128_ps(z9, u); \
        _MM_TRANSPOSE4_PS(m0, m1, m2, m3); \
        STORE_4(dst + 8 * (0 + 4 * u + 8 * v), m0); \
        STORE_4(dst + 8 * (1 + 4 * u + 8 * v), m1); \
        STORE_4(dst + 8 * (2 + 4 * u + 8 * v), m2); \
        STORE_4(dst + 8 * (3 + 4 * u + 8 * v), m3); \
    }

#define FMLA_TRANSPOSE_SAVE(u, v, z0, z3, z6, z9) \
    { \
        auto tmp_m0 = LOAD4(dst + 8 * (0 + 4 * u + 8 * v)); \
        auto tmp_m1 = LOAD4(dst + 8 * (1 + 4 * u + 8 * v)); \
        auto tmp_m2 = LOAD4(dst + 8 * (2 + 4 * u + 8 * v)); \
        auto tmp_m3 = LOAD4(dst + 8 * (3 + 4 * u + 8 * v)); \
        auto m0 = _mm256_extractf128_ps(z0, u); \
        auto m1 = _mm256_extractf128_ps(z3, u); \
        auto m2 = _mm256_extractf128_ps(z6, u); \
        auto m3 = _mm256_extractf128_ps(z9, u); \
        _MM_TRANSPOSE4_PS(m0, m1, m2, m3); \
        m0 = _mm_add_ps(tmp_m0, m0); \
        m1 = _mm_add_ps(tmp_m1, m1); \
        m2 = _mm_add_ps(tmp_m2, m2); \
        m3 = _mm_add_ps(tmp_m3, m3); \
        STORE_4(dst + 8 * (0 + 4 * u + 8 * v), m0); \
        STORE_4(dst + 8 * (1 + 4 * u + 8 * v), m1); \
        STORE_4(dst + 8 * (2 + 4 * u + 8 * v), m2); \
        STORE_4(dst + 8 * (3 + 4 * u + 8 * v), m3); \
    }

namespace {
static inline __m128i mm_loadu_si128(const void* addr) {
    return _mm_castps_si128(LOAD4((const float*)addr));
}
static inline __m256i mm256_broadcastsi128_si256(const void* addr) {
    return _mm256_broadcastsi128_si256(mm_loadu_si128(addr));
}
} // namespace
//
#define INIT_MAIN_24_4 \
    auto s0 = LOAD8(A + 0 * 24); \
    auto s1 = LOAD8(A + 0 * 24 + 8); \
    auto s2 = LOAD8(A + 0 * 24 + 16); \
    auto w0 = BROAD_LOAD(weight + 0 * 4 + 0); \
    auto z0 = _mm256_mul_ps(s0, w0); \
    auto z1 = _mm256_mul_ps(s1, w0); \
    auto z2 = _mm256_mul_ps(s2, w0); \
    auto w1 = BROAD_LOAD(weight + 0 * 4 + 1); \
    auto z3 = _mm256_mul_ps(s0, w1); \
    auto z4 = _mm256_mul_ps(s1, w1); \
    auto z5 = _mm256_mul_ps(s2, w1); \
    auto w2 = BROAD_LOAD(weight + 0 * 4 + 2); \
    auto z6 = _mm256_mul_ps(s0, w2); \
    auto z7 = _mm256_mul_ps(s1, w2); \
    auto z8 = _mm256_mul_ps(s2, w2); \
    auto w3 = BROAD_LOAD(weight + 0 * 4 + 3); \
    auto z9 = _mm256_mul_ps(s0, w3); \
    auto z10 = _mm256_mul_ps(s1, w3); \
    auto z11 = _mm256_mul_ps(s2, w3);

#define COMPUTE_24_4 \
    s0 = LOAD8(A + sy * 24); \
    s1 = LOAD8(A + sy * 24 + 8); \
    s2 = LOAD8(A + sy * 24 + 16); \
    w0 = BROAD_LOAD(weight + sy * 4 + 0); \
    z0 = MNNAVXFMA(s0, w0, z0); \
    z1 = MNNAVXFMA(s1, w0, z1); \
    z2 = MNNAVXFMA(s2, w0, z2); \
    w1 = BROAD_LOAD(weight + sy * 4 + 1); \
    z3 = MNNAVXFMA(s0, w1, z3); \
    z4 = MNNAVXFMA(s1, w1, z4); \
    z5 = MNNAVXFMA(s2, w1, z5); \
    w2 = BROAD_LOAD(weight + sy * 4 + 2); \
    z6 = MNNAVXFMA(s0, w2, z6); \
    z7 = MNNAVXFMA(s1, w2, z7); \
    z8 = MNNAVXFMA(s2, w2, z8); \
    w3 = BROAD_LOAD(weight + sy * 4 + 3); \
    z9 = MNNAVXFMA(s0, w3, z9); \
    z10 = MNNAVXFMA(s1, w3, z10); \
    z11 = MNNAVXFMA(s2, w3, z11);
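// Note (inferred from the kernels in this file): A is packed e x l with the
// e-tile as the inner stride (MNN_UNIT_E == 24 floats per l-step in the main
// kernel), B packs 4 output channels per l-step, and C is written in rows of
// 8 floats holding two 4-channel halves (selected by "4 * (y % 2)").
// parameter[] usage: [0] aStride in bytes, [1] l, [2] h, [3] cStride in bytes,
// [5] bExtraStride in bytes, [6] blockId (used by the quantized kernels).
// Main GEMM kernel for a full e = 24 tile: twelve 256-bit accumulators hold
// the 24x4 result, which is then transposed out in 4x4 blocks.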
template <typename TYPE>
static void _AVX_MNNPackedMatMul_Main(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_24_4;
        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_24_4;
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
        TRANPOSE_SAVE(1, 2, z2, z5, z8, z11);
    }
}

#define EXPAND_128(x) _mm256_castsi256_ps(_mm256_broadcastsi128_si256(_mm_castps_si128((x))))
//
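// Tail kernels for e = 20 and e = 16. The e = 20 variant loads two full
// 8-lane vectors plus a 4-lane remainder widened with EXPAND_128 (its upper
// half is computed but never stored); e = 16 simply drops the third column
// group. Both keep the main kernel's accumulate-then-transpose store.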
#define INIT_MAIN_20_4 \
    auto s0 = LOAD8(A + 0 * aStride); \
    auto s1 = LOAD8(A + 0 * aStride + 8); \
    auto s2 = EXPAND_128(LOAD4(A + 0 * aStride + 16)); \
    auto w0 = BROAD_LOAD(weight + 0 * 4 + 0); \
    auto z0 = _mm256_mul_ps(s0, w0); \
    auto z1 = _mm256_mul_ps(s1, w0); \
    auto z2 = _mm256_mul_ps(s2, w0); \
    auto w1 = BROAD_LOAD(weight + 0 * 4 + 1); \
    auto z3 = _mm256_mul_ps(s0, w1); \
    auto z4 = _mm256_mul_ps(s1, w1); \
    auto z5 = _mm256_mul_ps(s2, w1); \
    auto w2 = BROAD_LOAD(weight + 0 * 4 + 2); \
    auto z6 = _mm256_mul_ps(s0, w2); \
    auto z7 = _mm256_mul_ps(s1, w2); \
    auto z8 = _mm256_mul_ps(s2, w2); \
    auto w3 = BROAD_LOAD(weight + 0 * 4 + 3); \
    auto z9 = _mm256_mul_ps(s0, w3); \
    auto z10 = _mm256_mul_ps(s1, w3); \
    auto z11 = _mm256_mul_ps(s2, w3);

#define COMPUTE_20_4 \
    s0 = LOAD8(A + sy * aStride); \
    s1 = LOAD8(A + sy * aStride + 8); \
    s2 = EXPAND_128(LOAD4(A + sy * aStride + 16)); \
    w0 = BROAD_LOAD(weight + sy * 4 + 0); \
    z0 = MNNAVXFMA(s0, w0, z0); \
    z1 = MNNAVXFMA(s1, w0, z1); \
    z2 = MNNAVXFMA(s2, w0, z2); \
    w1 = BROAD_LOAD(weight + sy * 4 + 1); \
    z3 = MNNAVXFMA(s0, w1, z3); \
    z4 = MNNAVXFMA(s1, w1, z4); \
    z5 = MNNAVXFMA(s2, w1, z5); \
    w2 = BROAD_LOAD(weight + sy * 4 + 2); \
    z6 = MNNAVXFMA(s0, w2, z6); \
    z7 = MNNAVXFMA(s1, w2, z7); \
    z8 = MNNAVXFMA(s2, w2, z8); \
    w3 = BROAD_LOAD(weight + sy * 4 + 3); \
    z9 = MNNAVXFMA(s0, w3, z9); \
    z10 = MNNAVXFMA(s1, w3, z10); \
    z11 = MNNAVXFMA(s2, w3, z11);

template <typename TYPE>
static void _AVX_MNNPackedMatMul_20(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_20_4;
        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_20_4;
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
    }
}

#define INIT_MAIN_16_4 \
    auto s0 = LOAD8(A + 0 * aStride); \
    auto s1 = LOAD8(A + 0 * aStride + 8); \
    auto w0 = BROAD_LOAD(weight + 0 * 4 + 0); \
    auto z0 = _mm256_mul_ps(s0, w0); \
    auto z1 = _mm256_mul_ps(s1, w0); \
    auto w1 = BROAD_LOAD(weight + 0 * 4 + 1); \
    auto z3 = _mm256_mul_ps(s0, w1); \
    auto z4 = _mm256_mul_ps(s1, w1); \
    auto w2 = BROAD_LOAD(weight + 0 * 4 + 2); \
    auto z6 = _mm256_mul_ps(s0, w2); \
    auto z7 = _mm256_mul_ps(s1, w2); \
    auto w3 = BROAD_LOAD(weight + 0 * 4 + 3); \
    auto z9 = _mm256_mul_ps(s0, w3); \
    auto z10 = _mm256_mul_ps(s1, w3);

#define COMPUTE_16_4 \
    s0 = LOAD8(A + sy * aStride); \
    s1 = LOAD8(A + sy * aStride + 8); \
    w0 = BROAD_LOAD(weight + sy * 4 + 0); \
    z0 = MNNAVXFMA(s0, w0, z0); \
    z1 = MNNAVXFMA(s1, w0, z1); \
    w1 = BROAD_LOAD(weight + sy * 4 + 1); \
    z3 = MNNAVXFMA(s0, w1, z3); \
    z4 = MNNAVXFMA(s1, w1, z4); \
    w2 = BROAD_LOAD(weight + sy * 4 + 2); \
    z6 = MNNAVXFMA(s0, w2, z6); \
    z7 = MNNAVXFMA(s1, w2, z7); \
    w3 = BROAD_LOAD(weight + sy * 4 + 3); \
    z9 = MNNAVXFMA(s0, w3, z9); \
    z10 = MNNAVXFMA(s1, w3, z10);

template <typename TYPE>
static void _AVX_MNNPackedMatMul_16(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        INIT_MAIN_16_4;
        for (int sy = 1; sy < l; ++sy) {
            COMPUTE_16_4;
        }
        TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
        TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
        TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
    }
}

#define DST_ADDR_UNPACK4(x) \
    auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8; \
    auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4; \
    auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8; \
    auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;
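// Narrow-e kernels (e = 5, 4, 3, 2). Each fuses four output-channel blocks
// per iteration (hC4Unit == 4) by packing two 4-channel weight vectors into
// one 256-bit register, then handles the remaining blocks (y in [hR, hC4))
// with a plain SSE loop.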
template <typename TYPE>
static void _AVX_MNNPackedMatMul_5(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto sumAvx40 = _mm256_setzero_ps();
        auto sumAvx41 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto S4 = BROAD_LOAD(srcUse + 4);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
            sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40);
            sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_8(dst0, sumAvx00);
        STORE_8(dst0 + 8, sumAvx10);
        STORE_8(dst0 + 16, sumAvx20);
        STORE_8(dst0 + 24, sumAvx30);
        STORE_8(dst0 + 32, sumAvx40);
        STORE_8(dst2, sumAvx01);
        STORE_8(dst2 + 8, sumAvx11);
        STORE_8(dst2 + 16, sumAvx21);
        STORE_8(dst2 + 24, sumAvx31);
        STORE_8(dst2 + 32, sumAvx41);
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto s3 = BROAD_LOAD_4(A + 0 * aStride + 3);
        auto s4 = BROAD_LOAD_4(A + 0 * aStride + 4);
        auto w0 = LOAD4(weight + 0 * 4);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        auto z2 = _mm_mul_ps(s2, w0);
        auto z3 = _mm_mul_ps(s3, w0);
        auto z4 = _mm_mul_ps(s4, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            s3 = BROAD_LOAD_4(A + sy * aStride + 3);
            s4 = BROAD_LOAD_4(A + sy * aStride + 4);
            w0 = LOAD4(weight + sy * 4);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
            z3 = MNNSSEFMA(s3, w0, z3);
            z4 = MNNSSEFMA(s4, w0, z4);
        }
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z1);
        STORE_4(dst + 8 * 2, z2);
        STORE_4(dst + 8 * 3, z3);
        STORE_4(dst + 8 * 4, z4);
    }
}
template <typename TYPE>
static void _AVX_MNNPackedMatMul_3(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
        STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
        STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0));
        STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
        STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
        STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1));
        STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
        STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
        STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0));
        STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
        STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
        STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1));
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto w0 = LOAD4(weight + 0 * 4);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        auto z2 = _mm_mul_ps(s2, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            w0 = LOAD4(weight + sy * 4);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
        }
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z1);
        STORE_4(dst + 8 * 2, z2);
    }
}

template <typename TYPE>
static void _AVX_MNNPackedMatMul_2(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
        STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
        STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
        STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
        STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
        STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
        STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
        STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto w0 = LOAD4(weight + 0 * 4);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            w0 = LOAD4(weight + sy * 4);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
        }
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z1);
    }
}
template <typename TYPE>
static void _AVX_MNNPackedMatMul_4(TYPE* C, const TYPE* A, const TYPE* B, const size_t* parameter) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto W0 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight0), mm_loadu_si128(weight1), 1));
            auto W1 = _mm256_castsi256_ps(_mm256_insertf128_si256(mm256_broadcastsi128_si256(weight2), mm_loadu_si128(weight3), 1));
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        STORE_8(dst0, sumAvx00);
        STORE_8(dst0 + 8, sumAvx10);
        STORE_8(dst0 + 16, sumAvx20);
        STORE_8(dst0 + 24, sumAvx30);
        STORE_8(dst2, sumAvx01);
        STORE_8(dst2 + 8, sumAvx11);
        STORE_8(dst2 + 16, sumAvx21);
        STORE_8(dst2 + 24, sumAvx31);
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto s0 = LOAD4(A + 0 * aStride);
        auto w0 = BROAD_LOAD_4(weight + 0 * 4 + 0);
        auto w1 = BROAD_LOAD_4(weight + 0 * 4 + 1);
        auto w2 = BROAD_LOAD_4(weight + 0 * 4 + 2);
        auto w3 = BROAD_LOAD_4(weight + 0 * 4 + 3);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z3 = _mm_mul_ps(s0, w1);
        auto z6 = _mm_mul_ps(s0, w2);
        auto z9 = _mm_mul_ps(s0, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD4(A + sy * aStride);
            w0 = BROAD_LOAD_4(weight + sy * 4 + 0);
            w1 = BROAD_LOAD_4(weight + sy * 4 + 1);
            w2 = BROAD_LOAD_4(weight + sy * 4 + 2);
            w3 = BROAD_LOAD_4(weight + sy * 4 + 3);
            z0 = MNNSSEFMA(s0, w0, z0);
            z3 = MNNSSEFMA(s0, w1, z3);
            z6 = MNNSSEFMA(s0, w2, z6);
            z9 = MNNSSEFMA(s0, w3, z9);
        }
        _MM_TRANSPOSE4_PS(z0, z3, z6, z9);
        STORE_4(dst + 8 * 0, z0);
        STORE_4(dst + 8 * 1, z3);
        STORE_4(dst + 8 * 2, z6);
        STORE_4(dst + 8 * 3, z9);
    }
}
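// Remainder driver for e < 24: peels at most one 20-wide and one 16-wide
// tile, then 5-wide tiles, then dispatches e = 4/3/2 to the kernels above.
// What remains (e == 1) is handled inline below, with the l-loop unrolled
// by 4 (lC4 full groups plus an lR..l scalar tail).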
template <typename TYPE>
static void _AVX_MNNPackednMatMulRemainCommon(TYPE* C, const TYPE* A, const TYPE* B, size_t eSize,
                                              const size_t* parameter) {
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(TYPE);
    auto bStride = bExtraStride + l * 4;
    auto hC4 = UP_DIV(h, 4);
    auto es = eSize;
    auto oC = C;
    auto aStride = parameter[0] / sizeof(TYPE);
    if (eSize >= 20) {
        _AVX_MNNPackedMatMul_20<TYPE>(C, A, B, parameter);
        eSize -= 20;
        C += 20 * 8;
        A += 20;
    }
    if (eSize >= 16) {
        _AVX_MNNPackedMatMul_16<TYPE>(C, A, B, parameter);
        eSize -= 16;
        C += 16 * 8;
        A += 16;
    }
    while (eSize >= 5) {
        _AVX_MNNPackedMatMul_5<TYPE>(C, A, B, parameter);
        eSize -= 5;
        C += 5 * 8;
        A += 5;
    }
    if (eSize == 4) {
        _AVX_MNNPackedMatMul_4<TYPE>(C, A, B, parameter);
        return;
    }
    if (eSize == 3) {
        _AVX_MNNPackedMatMul_3<TYPE>(C, A, B, parameter);
        return;
    }
    if (eSize == 2) {
        _AVX_MNNPackedMatMul_2<TYPE>(C, A, B, parameter);
        return;
    }
    if (eSize == 0) {
        return;
    }
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    int x = 0;
    for (int y = 0; y < hC16; ++y) {
        auto weight0 = B + (hC4Unit * y + 0) * bStride;
        auto weight1 = B + (hC4Unit * y + 1) * bStride;
        auto weight2 = B + (hC4Unit * y + 2) * bStride;
        auto weight3 = B + (hC4Unit * y + 3) * bStride;
        auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;
        auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4;
        auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;
        auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W00 = LOAD8(weight0 + 16 * sy + 0);
            auto W01 = LOAD8(weight0 + 16 * sy + 8);
            auto W10 = LOAD8(weight1 + 16 * sy + 0);
            auto W11 = LOAD8(weight1 + 16 * sy + 8);
            auto W20 = LOAD8(weight2 + 16 * sy + 0);
            auto W21 = LOAD8(weight2 + 16 * sy + 8);
            auto W30 = LOAD8(weight3 + 16 * sy + 0);
            auto W31 = LOAD8(weight3 + 16 * sy + 8);
            sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00);
            sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01);
            sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11);
            sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20);
            sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21);
            sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30);
            sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31);
            srcUse += 4 * aStride;
        }
        sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01);
        sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11);
        sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21);
        sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31);
        auto sum00 = _mm256_extractf128_ps(sumAvx00, 0);
        auto sum01 = _mm256_extractf128_ps(sumAvx00, 1);
        auto sum0 = _mm_add_ps(sum00, sum01);
        auto sum10 = _mm256_extractf128_ps(sumAvx10, 0);
        auto sum11 = _mm256_extractf128_ps(sumAvx10, 1);
        auto sum1 = _mm_add_ps(sum10, sum11);
        auto sum20 = _mm256_extractf128_ps(sumAvx20, 0);
        auto sum21 = _mm256_extractf128_ps(sumAvx20, 1);
        auto sum2 = _mm_add_ps(sum20, sum21);
        auto sum30 = _mm256_extractf128_ps(sumAvx30, 0);
        auto sum31 = _mm256_extractf128_ps(sumAvx30, 1);
        auto sum3 = _mm_add_ps(sum30, sum31);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w0 = LOAD4(weight0 + 4 * sy);
            auto w1 = LOAD4(weight1 + 4 * sy);
            auto w2 = LOAD4(weight2 + 4 * sy);
            auto w3 = LOAD4(weight3 + 4 * sy);
            sum0 = MNNSSEFMA(s, w0, sum0);
            sum1 = MNNSSEFMA(s, w1, sum1);
            sum2 = MNNSSEFMA(s, w2, sum2);
            sum3 = MNNSSEFMA(s, w3, sum3);
            srcUse += aStride;
        }
        STORE_4(dst0, sum0);
        STORE_4(dst1, sum1);
        STORE_4(dst2, sum2);
        STORE_4(dst3, sum3);
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + x * 8 + 4 * (y % 2);
        auto sumAvx0 = _mm256_setzero_ps();
        auto sumAvx1 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W0 = LOAD8(weight + 16 * sy + 0);
            auto W1 = LOAD8(weight + 16 * sy + 8);
            sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0);
            sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1);
            srcUse += 4 * aStride;
        }
        sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1);
        auto sum0 = _mm256_extractf128_ps(sumAvx0, 0);
        auto sum1 = _mm256_extractf128_ps(sumAvx0, 1);
        auto sum = _mm_add_ps(sum0, sum1);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w = LOAD4(weight + 4 * sy);
            sum = MNNSSEFMA(s, w, sum);
            srcUse += aStride;
        }
        STORE_4(dst, sum);
    }
}
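// Weight-dequantizing GEMM variants (float activations, quantized weights).
// int4 weights are packed two per byte, high nibble first (w0 = byte / 16,
// w1 = byte % 16); k/alpha and b/bias hold the per-channel scale and offset
// applied when a weight is loaded.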
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
//----------------------- MatMul(float, int4) Functions ---------------------------//
#define LOAD_WEIGHT_ALPHA_BIAS_int4x4 \
    auto weight0 = B + (hC4Unit * y + 0) * bStride / 2; \
    auto weight1 = B + (hC4Unit * y + 1) * bStride / 2; \
    auto weight2 = B + (hC4Unit * y + 2) * bStride / 2; \
    auto weight3 = B + (hC4Unit * y + 3) * bStride / 2; \
    auto alpha0 = _mm_loadu_ps(k + y * 16 + 0); \
    auto alpha1 = _mm_loadu_ps(k + y * 16 + 4); \
    auto alpha2 = _mm_loadu_ps(k + y * 16 + 8); \
    auto alpha3 = _mm_loadu_ps(k + y * 16 + 12); \
    auto bias0 = _mm_loadu_ps(b + y * 16 + 0); \
    auto bias1 = _mm_loadu_ps(b + y * 16 + 4); \
    auto bias2 = _mm_loadu_ps(b + y * 16 + 8); \
    auto bias3 = _mm_loadu_ps(b + y * 16 + 12);

#define LOAD_ALPHA_BIAS_DOUBLE \
    auto alpha0_2 = _mm256_set_m128(alpha0, alpha0); \
    auto alpha1_2 = _mm256_set_m128(alpha1, alpha1); \
    auto alpha2_2 = _mm256_set_m128(alpha2, alpha2); \
    auto alpha3_2 = _mm256_set_m128(alpha3, alpha3); \
    auto bias0_2 = _mm256_set_m128(bias0, bias0); \
    auto bias1_2 = _mm256_set_m128(bias1, bias1); \
    auto bias2_2 = _mm256_set_m128(bias2, bias2); \
    auto bias3_2 = _mm256_set_m128(bias3, bias3);

static inline __m128 _load_int4x4(const uint8_t* src, __m128 alpha, __m128 bias) {
    auto w01 = src[0];
    auto w23 = src[1];
    int iw01 = w01;
    int iw23 = w23;
    int iw0 = iw01 / 16;
    int iw1 = iw01 % 16;
    int iw2 = iw23 / 16;
    int iw3 = iw23 % 16;
    auto ws = _mm_set_ps(iw3, iw2, iw1, iw0);
    ws = _mm_add_ps(_mm_mul_ps(ws, alpha), bias);
    return ws;
}

static inline __m256 _load_int4x8(const uint8_t* src, __m256 alpha, __m256 bias) {
    float w[8];
    for (int i = 0; i < 4; i++) {
        int x = src[i];
        int a = x / 16;
        int b = x % 16;
        w[i * 2] = a;
        w[i * 2 + 1] = b;
    }
    auto w8 = LOAD8(w);
    return _mm256_add_ps(_mm256_mul_ps(w8, alpha), bias);
}
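// blockId (parameter[6]) distinguishes quantization blocks along l: block 0
// overwrites C, later blocks accumulate into it, which is what the
// FMLA_TRANSPOSE_SAVE / load-add-store paths below implement.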
template <typename TYPE>
static void _AVX_MNNPackedMatMul_Main_int4(TYPE* C, const TYPE* A, const TYPE* fB, const size_t* parameter,
                                           const float* k, const float* b) {
    auto B = reinterpret_cast<const uint8_t*>(fB);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5; // sizeof(int4_t)
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto hC4 = UP_DIV(h, 4);
    float ws_tmp[4];
    size_t blockId = parameter[6];
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD8(A + 0 * 24);
        auto s1 = LOAD8(A + 0 * 24 + 8);
        auto s2 = LOAD8(A + 0 * 24 + 16);
        auto ws = _load_int4x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm256_set1_ps(ws_tmp[0]);
        auto w1 = _mm256_set1_ps(ws_tmp[1]);
        auto w2 = _mm256_set1_ps(ws_tmp[2]);
        auto w3 = _mm256_set1_ps(ws_tmp[3]);
        auto z0 = _mm256_mul_ps(s0, w0);
        auto z1 = _mm256_mul_ps(s1, w0);
        auto z2 = _mm256_mul_ps(s2, w0);
        auto z3 = _mm256_mul_ps(s0, w1);
        auto z4 = _mm256_mul_ps(s1, w1);
        auto z5 = _mm256_mul_ps(s2, w1);
        auto z6 = _mm256_mul_ps(s0, w2);
        auto z7 = _mm256_mul_ps(s1, w2);
        auto z8 = _mm256_mul_ps(s2, w2);
        auto z9 = _mm256_mul_ps(s0, w3);
        auto z10 = _mm256_mul_ps(s1, w3);
        auto z11 = _mm256_mul_ps(s2, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD8(A + sy * 24);
            s1 = LOAD8(A + sy * 24 + 8);
            s2 = LOAD8(A + sy * 24 + 16);
            ws = _load_int4x4(weight + sy * 2, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm256_set1_ps(ws_tmp[0]);
            w1 = _mm256_set1_ps(ws_tmp[1]);
            w2 = _mm256_set1_ps(ws_tmp[2]);
            w3 = _mm256_set1_ps(ws_tmp[3]);
            z0 = MNNAVXFMA(s0, w0, z0);
            z1 = MNNAVXFMA(s1, w0, z1);
            z2 = MNNAVXFMA(s2, w0, z2);
            z3 = MNNAVXFMA(s0, w1, z3);
            z4 = MNNAVXFMA(s1, w1, z4);
            z5 = MNNAVXFMA(s2, w1, z5);
            z6 = MNNAVXFMA(s0, w2, z6);
            z7 = MNNAVXFMA(s1, w2, z7);
            z8 = MNNAVXFMA(s2, w2, z8);
            z9 = MNNAVXFMA(s0, w3, z9);
            z10 = MNNAVXFMA(s1, w3, z10);
            z11 = MNNAVXFMA(s2, w3, z11);
        }
        if (blockId == 0) {
            TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
            TRANPOSE_SAVE(1, 2, z2, z5, z8, z11);
        } else {
            FMLA_TRANSPOSE_SAVE(0, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(1, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(0, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(1, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(0, 2, z2, z5, z8, z11);
            FMLA_TRANSPOSE_SAVE(1, 2, z2, z5, z8, z11);
        }
    }
}
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int4_20(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter,
                                         const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5; // sizeof(int4_t)
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    float ws_tmp[4];
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD8(A + 0 * aStride);
        auto s1 = LOAD8(A + 0 * aStride + 8);
        auto s2 = EXPAND_128(LOAD4(A + 0 * aStride + 16));
        auto ws = _load_int4x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm256_set1_ps(ws_tmp[0]);
        auto w1 = _mm256_set1_ps(ws_tmp[1]);
        auto w2 = _mm256_set1_ps(ws_tmp[2]);
        auto w3 = _mm256_set1_ps(ws_tmp[3]);
        auto z0 = _mm256_mul_ps(s0, w0);
        auto z1 = _mm256_mul_ps(s1, w0);
        auto z2 = _mm256_mul_ps(s2, w0);
        auto z3 = _mm256_mul_ps(s0, w1);
        auto z4 = _mm256_mul_ps(s1, w1);
        auto z5 = _mm256_mul_ps(s2, w1);
        auto z6 = _mm256_mul_ps(s0, w2);
        auto z7 = _mm256_mul_ps(s1, w2);
        auto z8 = _mm256_mul_ps(s2, w2);
        auto z9 = _mm256_mul_ps(s0, w3);
        auto z10 = _mm256_mul_ps(s1, w3);
        auto z11 = _mm256_mul_ps(s2, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD8(A + sy * aStride);
            s1 = LOAD8(A + sy * aStride + 8);
            s2 = EXPAND_128(LOAD4(A + sy * aStride + 16));
            ws = _load_int4x4(weight + sy * 2, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm256_set1_ps(ws_tmp[0]);
            w1 = _mm256_set1_ps(ws_tmp[1]);
            w2 = _mm256_set1_ps(ws_tmp[2]);
            w3 = _mm256_set1_ps(ws_tmp[3]);
            z0 = MNNAVXFMA(s0, w0, z0);
            z1 = MNNAVXFMA(s1, w0, z1);
            z2 = MNNAVXFMA(s2, w0, z2);
            z3 = MNNAVXFMA(s0, w1, z3);
            z4 = MNNAVXFMA(s1, w1, z4);
            z5 = MNNAVXFMA(s2, w1, z5);
            z6 = MNNAVXFMA(s0, w2, z6);
            z7 = MNNAVXFMA(s1, w2, z7);
            z8 = MNNAVXFMA(s2, w2, z8);
            z9 = MNNAVXFMA(s0, w3, z9);
            z10 = MNNAVXFMA(s1, w3, z10);
            z11 = MNNAVXFMA(s2, w3, z11);
        }
        if (0 == blockId) {
            TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
        } else {
            FMLA_TRANSPOSE_SAVE(0, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(1, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(0, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(1, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(0, 2, z2, z5, z8, z11);
        }
    }
}

template <typename TYPE>
static void _AVX_MNNPackedMatMul_int4_16(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter,
                                         const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5; // sizeof(int4_t)
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    float ws_tmp[4];
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD8(A + 0 * aStride);
        auto s1 = LOAD8(A + 0 * aStride + 8);
        auto ws = _load_int4x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm256_set1_ps(ws_tmp[0]);
        auto w1 = _mm256_set1_ps(ws_tmp[1]);
        auto w2 = _mm256_set1_ps(ws_tmp[2]);
        auto w3 = _mm256_set1_ps(ws_tmp[3]);
        auto z0 = _mm256_mul_ps(s0, w0);
        auto z1 = _mm256_mul_ps(s1, w0);
        auto z3 = _mm256_mul_ps(s0, w1);
        auto z4 = _mm256_mul_ps(s1, w1);
        auto z6 = _mm256_mul_ps(s0, w2);
        auto z7 = _mm256_mul_ps(s1, w2);
        auto z9 = _mm256_mul_ps(s0, w3);
        auto z10 = _mm256_mul_ps(s1, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD8(A + sy * aStride);
            s1 = LOAD8(A + sy * aStride + 8);
            ws = _load_int4x4(weight + sy * 2, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm256_set1_ps(ws_tmp[0]);
            w1 = _mm256_set1_ps(ws_tmp[1]);
            w2 = _mm256_set1_ps(ws_tmp[2]);
            w3 = _mm256_set1_ps(ws_tmp[3]);
            z0 = MNNAVXFMA(s0, w0, z0);
            z1 = MNNAVXFMA(s1, w0, z1);
            z3 = MNNAVXFMA(s0, w1, z3);
            z4 = MNNAVXFMA(s1, w1, z4);
            z6 = MNNAVXFMA(s0, w2, z6);
            z7 = MNNAVXFMA(s1, w2, z7);
            z9 = MNNAVXFMA(s0, w3, z9);
            z10 = MNNAVXFMA(s1, w3, z10);
        }
        if (0 == blockId) {
            TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        } else {
            FMLA_TRANSPOSE_SAVE(0, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(1, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(0, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(1, 1, z1, z4, z7, z10);
        }
    }
}
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int4_5(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter,
                                        const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5;
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int4x4
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto sumAvx40 = _mm256_setzero_ps();
        auto sumAvx41 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto S4 = BROAD_LOAD(srcUse + 4);
            auto w0 = _load_int4x4(weight0, alpha0, bias0);
            auto w1 = _load_int4x4(weight1, alpha1, bias1);
            auto w2 = _load_int4x4(weight2, alpha2, bias2);
            auto w3 = _load_int4x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
            sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40);
            sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41);
            srcUse += aStride;
            weight0 += 2;
            weight1 += 2;
            weight2 += 2;
            weight3 += 2;
        }
        if (0 == blockId) {
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst0 + 32, sumAvx40);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
            STORE_8(dst2 + 32, sumAvx41);
        } else {
            auto tmp0 = LOAD8(dst0);
            auto tmp1 = LOAD8(dst0 + 8);
            auto tmp2 = LOAD8(dst0 + 16);
            auto tmp3 = LOAD8(dst0 + 24);
            auto tmp4 = LOAD8(dst0 + 32);
            auto tmp5 = LOAD8(dst2);
            auto tmp6 = LOAD8(dst2 + 8);
            auto tmp7 = LOAD8(dst2 + 16);
            auto tmp8 = LOAD8(dst2 + 24);
            auto tmp9 = LOAD8(dst2 + 32);
            sumAvx00 = _mm256_add_ps(sumAvx00, tmp0);
            sumAvx10 = _mm256_add_ps(sumAvx10, tmp1);
            sumAvx20 = _mm256_add_ps(sumAvx20, tmp2);
            sumAvx30 = _mm256_add_ps(sumAvx30, tmp3);
            sumAvx40 = _mm256_add_ps(sumAvx40, tmp4);
            sumAvx01 = _mm256_add_ps(sumAvx01, tmp5);
            sumAvx11 = _mm256_add_ps(sumAvx11, tmp6);
            sumAvx21 = _mm256_add_ps(sumAvx21, tmp7);
            sumAvx31 = _mm256_add_ps(sumAvx31, tmp8);
            sumAvx41 = _mm256_add_ps(sumAvx41, tmp9);
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst0 + 32, sumAvx40);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
            STORE_8(dst2 + 32, sumAvx41);
        }
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto s3 = BROAD_LOAD_4(A + 0 * aStride + 3);
        auto s4 = BROAD_LOAD_4(A + 0 * aStride + 4);
        auto w0 = _load_int4x4(weight, alpha, bias);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        auto z2 = _mm_mul_ps(s2, w0);
        auto z3 = _mm_mul_ps(s3, w0);
        auto z4 = _mm_mul_ps(s4, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            s3 = BROAD_LOAD_4(A + sy * aStride + 3);
            s4 = BROAD_LOAD_4(A + sy * aStride + 4);
            w0 = _load_int4x4(weight + sy * 2, alpha, bias);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
            z3 = MNNSSEFMA(s3, w0, z3);
            z4 = MNNSSEFMA(s4, w0, z4);
        }
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
            STORE_4(dst + 8 * 3, z3);
            STORE_4(dst + 8 * 4, z4);
        } else {
            auto tmp0 = LOAD4(dst + 8 * 0);
            auto tmp1 = LOAD4(dst + 8 * 1);
            auto tmp2 = LOAD4(dst + 8 * 2);
            auto tmp3 = LOAD4(dst + 8 * 3);
            auto tmp4 = LOAD4(dst + 8 * 4);
            z0 = _mm_add_ps(tmp0, z0);
            z1 = _mm_add_ps(tmp1, z1);
            z2 = _mm_add_ps(tmp2, z2);
            z3 = _mm_add_ps(tmp3, z3);
            z4 = _mm_add_ps(tmp4, z4);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
            STORE_4(dst + 8 * 3, z3);
            STORE_4(dst + 8 * 4, z4);
        }
    }
}
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int4_4(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter,
                                        const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5; // sizeof(int4_t)
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int4x4
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto w0 = _load_int4x4(weight0, alpha0, bias0);
            auto w1 = _load_int4x4(weight1, alpha1, bias1);
            auto w2 = _load_int4x4(weight2, alpha2, bias2);
            auto w3 = _load_int4x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
            srcUse += aStride;
            weight0 += 2;
            weight1 += 2;
            weight2 += 2;
            weight3 += 2;
        }
        if (0 == blockId) {
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
        } else {
            auto tmp0 = LOAD8(dst0);
            auto tmp1 = LOAD8(dst0 + 8);
            auto tmp2 = LOAD8(dst0 + 16);
            auto tmp3 = LOAD8(dst0 + 24);
            auto tmp5 = LOAD8(dst2);
            auto tmp6 = LOAD8(dst2 + 8);
            auto tmp7 = LOAD8(dst2 + 16);
            auto tmp8 = LOAD8(dst2 + 24);
            sumAvx00 = _mm256_add_ps(sumAvx00, tmp0);
            sumAvx10 = _mm256_add_ps(sumAvx10, tmp1);
            sumAvx20 = _mm256_add_ps(sumAvx20, tmp2);
            sumAvx30 = _mm256_add_ps(sumAvx30, tmp3);
            sumAvx01 = _mm256_add_ps(sumAvx01, tmp5);
            sumAvx11 = _mm256_add_ps(sumAvx11, tmp6);
            sumAvx21 = _mm256_add_ps(sumAvx21, tmp7);
            sumAvx31 = _mm256_add_ps(sumAvx31, tmp8);
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
        }
    }
    float ws_tmp[4];
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD4(A + 0 * aStride);
        auto ws = _load_int4x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm_set1_ps(ws_tmp[0]);
        auto w1 = _mm_set1_ps(ws_tmp[1]);
        auto w2 = _mm_set1_ps(ws_tmp[2]);
        auto w3 = _mm_set1_ps(ws_tmp[3]);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z3 = _mm_mul_ps(s0, w1);
        auto z6 = _mm_mul_ps(s0, w2);
        auto z9 = _mm_mul_ps(s0, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD4(A + sy * aStride);
            ws = _load_int4x4(weight + sy * 2, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm_set1_ps(ws_tmp[0]);
            w1 = _mm_set1_ps(ws_tmp[1]);
            w2 = _mm_set1_ps(ws_tmp[2]);
            w3 = _mm_set1_ps(ws_tmp[3]);
            z0 = MNNSSEFMA(s0, w0, z0);
            z3 = MNNSSEFMA(s0, w1, z3);
            z6 = MNNSSEFMA(s0, w2, z6);
            z9 = MNNSSEFMA(s0, w3, z9);
        }
        _MM_TRANSPOSE4_PS(z0, z3, z6, z9);
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z3);
            STORE_4(dst + 8 * 2, z6);
            STORE_4(dst + 8 * 3, z9);
        } else {
            auto tmp0 = LOAD4(dst + 8 * 0);
            auto tmp1 = LOAD4(dst + 8 * 1);
            auto tmp2 = LOAD4(dst + 8 * 2);
            auto tmp3 = LOAD4(dst + 8 * 3);
            z0 = _mm_add_ps(tmp0, z0);
            z3 = _mm_add_ps(tmp1, z3);
            z6 = _mm_add_ps(tmp2, z6);
            z9 = _mm_add_ps(tmp3, z9);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z3);
            STORE_4(dst + 8 * 2, z6);
            STORE_4(dst + 8 * 3, z9);
        }
    }
}

template <typename TYPE>
static void _AVX_MNNPackedMatMul_int4_3(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter,
                                        const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5; // sizeof(int4_t)
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int4x4
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto w0 = _load_int4x4(weight0, alpha0, bias0);
            auto w1 = _load_int4x4(weight1, alpha1, bias1);
            auto w2 = _load_int4x4(weight2, alpha2, bias2);
            auto w3 = _load_int4x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            srcUse += aStride;
            weight0 += 2;
            weight1 += 2;
            weight2 += 2;
            weight3 += 2;
        }
        if (0 == blockId) {
            STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
            STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
            STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0));
            STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
            STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
            STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1));
            STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
            STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
            STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0));
            STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
            STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
            STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1));
        } else {
            auto tmp00 = LOAD4(dst0 + 0);
            auto tmp01 = LOAD4(dst0 + 8);
            auto tmp02 = LOAD4(dst0 + 16);
            auto tmp10 = LOAD4(dst1 + 0);
            auto tmp11 = LOAD4(dst1 + 8);
            auto tmp12 = LOAD4(dst1 + 16);
            auto tmp20 = LOAD4(dst2 + 0);
            auto tmp21 = LOAD4(dst2 + 8);
            auto tmp22 = LOAD4(dst2 + 16);
            auto tmp30 = LOAD4(dst3 + 0);
            auto tmp31 = LOAD4(dst3 + 8);
            auto tmp32 = LOAD4(dst3 + 16);
            auto sum_tmp00 = _mm256_extractf128_ps(sumAvx00, 0);
            auto sum_tmp01 = _mm256_extractf128_ps(sumAvx10, 0);
            auto sum_tmp02 = _mm256_extractf128_ps(sumAvx20, 0);
            auto sum_tmp10 = _mm256_extractf128_ps(sumAvx00, 1);
            auto sum_tmp11 = _mm256_extractf128_ps(sumAvx10, 1);
            auto sum_tmp12 = _mm256_extractf128_ps(sumAvx20, 1);
            auto sum_tmp20 = _mm256_extractf128_ps(sumAvx01, 0);
            auto sum_tmp21 = _mm256_extractf128_ps(sumAvx11, 0);
            auto sum_tmp22 = _mm256_extractf128_ps(sumAvx21, 0);
            auto sum_tmp30 = _mm256_extractf128_ps(sumAvx01, 1);
            auto sum_tmp31 = _mm256_extractf128_ps(sumAvx11, 1);
            auto sum_tmp32 = _mm256_extractf128_ps(sumAvx21, 1);
            sum_tmp00 = _mm_add_ps(tmp00, sum_tmp00);
            sum_tmp01 = _mm_add_ps(tmp01, sum_tmp01);
            sum_tmp02 = _mm_add_ps(tmp02, sum_tmp02);
            sum_tmp10 = _mm_add_ps(tmp10, sum_tmp10);
            sum_tmp11 = _mm_add_ps(tmp11, sum_tmp11);
            sum_tmp12 = _mm_add_ps(tmp12, sum_tmp12);
            sum_tmp20 = _mm_add_ps(tmp20, sum_tmp20);
            sum_tmp21 = _mm_add_ps(tmp21, sum_tmp21);
            sum_tmp22 = _mm_add_ps(tmp22, sum_tmp22);
            sum_tmp30 = _mm_add_ps(tmp30, sum_tmp30);
            sum_tmp31 = _mm_add_ps(tmp31, sum_tmp31);
            sum_tmp32 = _mm_add_ps(tmp32, sum_tmp32);
            STORE_4(dst0 + 0, sum_tmp00);
            STORE_4(dst0 + 8, sum_tmp01);
            STORE_4(dst0 + 16, sum_tmp02);
            STORE_4(dst1 + 0, sum_tmp10);
            STORE_4(dst1 + 8, sum_tmp11);
            STORE_4(dst1 + 16, sum_tmp12);
            STORE_4(dst2 + 0, sum_tmp20);
            STORE_4(dst2 + 8, sum_tmp21);
            STORE_4(dst2 + 16, sum_tmp22);
            STORE_4(dst3 + 0, sum_tmp30);
            STORE_4(dst3 + 8, sum_tmp31);
            STORE_4(dst3 + 16, sum_tmp32);
        }
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto w0 = _load_int4x4(weight, alpha, bias);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        auto z2 = _mm_mul_ps(s2, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            w0 = _load_int4x4(weight + sy * 2, alpha, bias);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
        }
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
        } else {
            auto tmp0 = LOAD4(dst + 8 * 0);
            auto tmp1 = LOAD4(dst + 8 * 1);
            auto tmp2 = LOAD4(dst + 8 * 2);
            z0 = _mm_add_ps(tmp0, z0);
            z1 = _mm_add_ps(tmp1, z1);
            z2 = _mm_add_ps(tmp2, z2);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
        }
    }
}

template <typename TYPE>
static void _AVX_MNNPackedMatMul_int4_2(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter,
                                        const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5;
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int4x4
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto w0 = _load_int4x4(weight0, alpha0, bias0);
            auto w1 = _load_int4x4(weight1, alpha1, bias1);
            auto w2 = _load_int4x4(weight2, alpha2, bias2);
            auto w3 = _load_int4x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            srcUse += aStride;
            weight0 += 2;
            weight1 += 2;
            weight2 += 2;
            weight3 += 2;
        }
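// int4 remainder driver: same peeling order as the float version (20, 16,
// 5-wide tiles, then e = 4/3/2), with the e == 1 tail dequantizing eight
// weights at a time via _load_int4x8.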
        if (0 == blockId) {
            STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
            STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
            STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
            STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
            STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
            STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
            STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
            STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
        } else {
            auto tmp01 = LOAD4(dst0 + 0);
            auto tmp02 = LOAD4(dst0 + 8);
            auto tmp11 = LOAD4(dst1 + 0);
            auto tmp12 = LOAD4(dst1 + 8);
            auto tmp21 = LOAD4(dst2 + 0);
            auto tmp22 = LOAD4(dst2 + 8);
            auto tmp31 = LOAD4(dst3 + 0);
            auto tmp32 = LOAD4(dst3 + 8);
            auto x_tmp01 = _mm256_extractf128_ps(sumAvx00, 0);
            auto x_tmp02 = _mm256_extractf128_ps(sumAvx10, 0);
            auto x_tmp11 = _mm256_extractf128_ps(sumAvx00, 1);
            auto x_tmp12 = _mm256_extractf128_ps(sumAvx10, 1);
            auto x_tmp21 = _mm256_extractf128_ps(sumAvx01, 0);
            auto x_tmp22 = _mm256_extractf128_ps(sumAvx11, 0);
            auto x_tmp31 = _mm256_extractf128_ps(sumAvx01, 1);
            auto x_tmp32 = _mm256_extractf128_ps(sumAvx11, 1);
            x_tmp01 = _mm_add_ps(tmp01, x_tmp01);
            x_tmp02 = _mm_add_ps(tmp02, x_tmp02);
            x_tmp11 = _mm_add_ps(tmp11, x_tmp11);
            x_tmp12 = _mm_add_ps(tmp12, x_tmp12);
            x_tmp21 = _mm_add_ps(tmp21, x_tmp21);
            x_tmp22 = _mm_add_ps(tmp22, x_tmp22);
            x_tmp31 = _mm_add_ps(tmp31, x_tmp31);
            x_tmp32 = _mm_add_ps(tmp32, x_tmp32);
            STORE_4(dst0 + 0, x_tmp01);
            STORE_4(dst0 + 8, x_tmp02);
            STORE_4(dst1 + 0, x_tmp11);
            STORE_4(dst1 + 8, x_tmp12);
            STORE_4(dst2 + 0, x_tmp21);
            STORE_4(dst2 + 8, x_tmp22);
            STORE_4(dst3 + 0, x_tmp31);
            STORE_4(dst3 + 8, x_tmp32);
        }
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto w0 = _load_int4x4(weight, alpha, bias);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            w0 = _load_int4x4(weight + sy * 2, alpha, bias);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
        }
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
        } else {
            auto t0 = LOAD4(dst + 8 * 0);
            auto t1 = LOAD4(dst + 8 * 1);
            z0 = _mm_add_ps(z0, t0);
            z1 = _mm_add_ps(z1, t1);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
        }
    }
}
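// int8 weights are stored one per byte; _load_int8x4 / _load_int8x8 widen
// them to float and apply the same per-channel alpha/bias as the int4 path.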
template <typename TYPE>
static void _AVX_MNNPackednMatMulRemainCommon_int4(TYPE* C, const TYPE* A, const TYPE* fB, size_t eSize,
                                                   const size_t* parameter, const float* k, const float* b) {
    auto B = reinterpret_cast<const uint8_t*>(fB);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    float weightBytes = 0.5; // sizeof(int4_t)
    auto bExtraStride = static_cast<int32_t>(parameter[5] / weightBytes);
    auto bStride = bExtraStride + 4 * l;
    auto hC4 = UP_DIV(h, 4);
    auto es = eSize;
    auto oC = C;
    auto aStride = parameter[0] / sizeof(TYPE);
    size_t blockId = parameter[6];
    if (eSize >= 20) {
        _AVX_MNNPackedMatMul_int4_20<TYPE>(C, A, B, parameter, k, b);
        eSize -= 20;
        C += 20 * 8;
        A += 20;
    }
    if (eSize >= 16) {
        _AVX_MNNPackedMatMul_int4_16<TYPE>(C, A, B, parameter, k, b);
        eSize -= 16;
        C += 16 * 8;
        A += 16;
    }
    while (eSize >= 5) {
        _AVX_MNNPackedMatMul_int4_5<TYPE>(C, A, B, parameter, k, b);
        eSize -= 5;
        C += 5 * 8;
        A += 5;
    }
    if (eSize == 4) {
        _AVX_MNNPackedMatMul_int4_4<TYPE>(C, A, B, parameter, k, b);
        return;
    }
    if (eSize == 3) {
        _AVX_MNNPackedMatMul_int4_3<TYPE>(C, A, B, parameter, k, b);
        return;
    }
    if (eSize == 2) {
        _AVX_MNNPackedMatMul_int4_2<TYPE>(C, A, B, parameter, k, b);
        return;
    }
    if (eSize == 0) {
        return;
    }
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    int x = 0;
    for (int y = 0; y < hC16; ++y) {
        auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;
        auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4;
        auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;
        auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;
        LOAD_WEIGHT_ALPHA_BIAS_int4x4
        LOAD_ALPHA_BIAS_DOUBLE
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W00 = _load_int4x8(weight0 + 8 * sy + 0, alpha0_2, bias0_2);
            auto W01 = _load_int4x8(weight0 + 8 * sy + 4, alpha0_2, bias0_2);
            auto W10 = _load_int4x8(weight1 + 8 * sy + 0, alpha1_2, bias1_2);
            auto W11 = _load_int4x8(weight1 + 8 * sy + 4, alpha1_2, bias1_2);
            auto W20 = _load_int4x8(weight2 + 8 * sy + 0, alpha2_2, bias2_2);
            auto W21 = _load_int4x8(weight2 + 8 * sy + 4, alpha2_2, bias2_2);
            auto W30 = _load_int4x8(weight3 + 8 * sy + 0, alpha3_2, bias3_2);
            auto W31 = _load_int4x8(weight3 + 8 * sy + 4, alpha3_2, bias3_2);
            sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00);
            sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01);
            sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11);
            sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20);
            sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21);
            sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30);
            sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31);
            srcUse += 4 * aStride;
        }
        sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01);
        sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11);
        sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21);
        sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31);
        auto sum00 = _mm256_extractf128_ps(sumAvx00, 0);
        auto sum01 = _mm256_extractf128_ps(sumAvx00, 1);
        auto sum0 = _mm_add_ps(sum00, sum01);
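// e = 24 main kernel over int8 weights: same scheme as
// _AVX_MNNPackedMatMul_Main_int4, but reading 4 weight bytes per l-step.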
        auto sum10 = _mm256_extractf128_ps(sumAvx10, 0);
        auto sum11 = _mm256_extractf128_ps(sumAvx10, 1);
        auto sum1 = _mm_add_ps(sum10, sum11);
        auto sum20 = _mm256_extractf128_ps(sumAvx20, 0);
        auto sum21 = _mm256_extractf128_ps(sumAvx20, 1);
        auto sum2 = _mm_add_ps(sum20, sum21);
        auto sum30 = _mm256_extractf128_ps(sumAvx30, 0);
        auto sum31 = _mm256_extractf128_ps(sumAvx30, 1);
        auto sum3 = _mm_add_ps(sum30, sum31);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w0 = _load_int4x4(weight0 + 2 * sy, alpha0, bias0);
            auto w1 = _load_int4x4(weight1 + 2 * sy, alpha1, bias1);
            auto w2 = _load_int4x4(weight2 + 2 * sy, alpha2, bias2);
            auto w3 = _load_int4x4(weight3 + 2 * sy, alpha3, bias3);
            sum0 = MNNSSEFMA(s, w0, sum0);
            sum1 = MNNSSEFMA(s, w1, sum1);
            sum2 = MNNSSEFMA(s, w2, sum2);
            sum3 = MNNSSEFMA(s, w3, sum3);
            srcUse += aStride;
        }
        if (blockId == 0) {
            STORE_4(dst0, sum0);
            STORE_4(dst1, sum1);
            STORE_4(dst2, sum2);
            STORE_4(dst3, sum3);
        } else {
            auto tmp_0 = LOAD4(dst0);
            auto tmp_1 = LOAD4(dst1);
            auto tmp_2 = LOAD4(dst2);
            auto tmp_3 = LOAD4(dst3);
            sum0 = _mm_add_ps(tmp_0, sum0);
            sum1 = _mm_add_ps(tmp_1, sum1);
            sum2 = _mm_add_ps(tmp_2, sum2);
            sum3 = _mm_add_ps(tmp_3, sum3);
            STORE_4(dst0, sum0);
            STORE_4(dst1, sum1);
            STORE_4(dst2, sum2);
            STORE_4(dst3, sum3);
        }
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride / 2;
        auto dst = C + (y / 2) * cStride + x * 8 + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto alpha_2 = _mm256_set_m128(alpha, alpha);
        auto bias_2 = _mm256_set_m128(bias, bias);
        auto sumAvx0 = _mm256_setzero_ps();
        auto sumAvx1 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W0 = _load_int4x8(weight + 8 * sy + 0, alpha_2, bias_2);
            auto W1 = _load_int4x8(weight + 8 * sy + 4, alpha_2, bias_2);
            sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0);
            sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1);
            srcUse += 4 * aStride;
        }
        sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1);
        auto sum0 = _mm256_extractf128_ps(sumAvx0, 0);
        auto sum1 = _mm256_extractf128_ps(sumAvx0, 1);
        auto sum = _mm_add_ps(sum0, sum1);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w = _load_int4x4(weight + sy * 2, alpha, bias);
            sum = MNNSSEFMA(s, w, sum);
            srcUse += aStride;
        }
        if (blockId == 0) {
            STORE_4(dst, sum);
        } else {
            auto tmp_0 = LOAD4(dst);
            sum = _mm_add_ps(tmp_0, sum);
            STORE_4(dst, sum);
        }
    }
}

//----------------------- MatMul(float, int8) Functions ---------------------------//
#define LOAD_WEIGHT_ALPHA_BIAS_int8x4 \
    auto weight0 = B + (hC4Unit * y + 0) * bStride; \
    auto weight1 = B + (hC4Unit * y + 1) * bStride; \
    auto weight2 = B + (hC4Unit * y + 2) * bStride; \
    auto weight3 = B + (hC4Unit * y + 3) * bStride; \
    auto alpha0 = _mm_loadu_ps(k + y * 16 + 0); \
    auto alpha1 = _mm_loadu_ps(k + y * 16 + 4); \
    auto alpha2 = _mm_loadu_ps(k + y * 16 + 8); \
    auto alpha3 = _mm_loadu_ps(k + y * 16 + 12); \
    auto bias0 = _mm_loadu_ps(b + y * 16 + 0); \
    auto bias1 = _mm_loadu_ps(b + y * 16 + 4); \
    auto bias2 = _mm_loadu_ps(b + y * 16 + 8); \
    auto bias3 = _mm_loadu_ps(b + y * 16 + 12);

static inline __m128 _load_int8x4(const int8_t* src, __m128 alpha, __m128 bias) {
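// e = 20 tail kernel over int8 weights, mirroring the int4 e = 20 variant.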
static inline __m128 _load_int8x4(const int8_t* src, __m128 alpha, __m128 bias) {
    int iw0 = src[0];
    int iw1 = src[1];
    int iw2 = src[2];
    int iw3 = src[3];
    auto ws = _mm_set_ps(iw3, iw2, iw1, iw0);
    ws = _mm_add_ps(_mm_mul_ps(ws, alpha), bias);
    return ws;
}

static inline __m256 _load_int8x8(const int8_t* src, __m256 alpha, __m256 bias) {
    float w[8];
    for (int i = 0; i < 8; i++) {
        w[i] = int(src[i]);
    }
    auto w8 = LOAD8(w);
    return _mm256_add_ps(_mm256_mul_ps(w8, alpha), bias);
}

template <typename TYPE>
static void _AVX_MNNPackedMatMul_Main_int8(TYPE* C, const TYPE* A, const TYPE* fB, const size_t* parameter, const float* k, const float* b) {
    auto B = reinterpret_cast<const int8_t*>(fB);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    int weightBytes = sizeof(int8_t);
    auto bExtraStride = parameter[5] / weightBytes;
    auto bStride = bExtraStride + 4 * l;
    auto hC4 = UP_DIV(h, 4);
    float ws_tmp[4];
    size_t blockId = parameter[6];
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD8(A + 0 * 24);
        auto s1 = LOAD8(A + 0 * 24 + 8);
        auto s2 = LOAD8(A + 0 * 24 + 16);
        auto ws = _load_int8x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm256_set1_ps(ws_tmp[0]);
        auto w1 = _mm256_set1_ps(ws_tmp[1]);
        auto w2 = _mm256_set1_ps(ws_tmp[2]);
        auto w3 = _mm256_set1_ps(ws_tmp[3]);
        auto z0 = _mm256_mul_ps(s0, w0);
        auto z1 = _mm256_mul_ps(s1, w0);
        auto z2 = _mm256_mul_ps(s2, w0);
        auto z3 = _mm256_mul_ps(s0, w1);
        auto z4 = _mm256_mul_ps(s1, w1);
        auto z5 = _mm256_mul_ps(s2, w1);
        auto z6 = _mm256_mul_ps(s0, w2);
        auto z7 = _mm256_mul_ps(s1, w2);
        auto z8 = _mm256_mul_ps(s2, w2);
        auto z9 = _mm256_mul_ps(s0, w3);
        auto z10 = _mm256_mul_ps(s1, w3);
        auto z11 = _mm256_mul_ps(s2, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD8(A + sy * 24);
            s1 = LOAD8(A + sy * 24 + 8);
            s2 = LOAD8(A + sy * 24 + 16);
            ws = _load_int8x4(weight + sy * 4, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm256_set1_ps(ws_tmp[0]);
            w1 = _mm256_set1_ps(ws_tmp[1]);
            w2 = _mm256_set1_ps(ws_tmp[2]);
            w3 = _mm256_set1_ps(ws_tmp[3]);
            z0 = MNNAVXFMA(s0, w0, z0);
            z1 = MNNAVXFMA(s1, w0, z1);
            z2 = MNNAVXFMA(s2, w0, z2);
            z3 = MNNAVXFMA(s0, w1, z3);
            z4 = MNNAVXFMA(s1, w1, z4);
            z5 = MNNAVXFMA(s2, w1, z5);
            z6 = MNNAVXFMA(s0, w2, z6);
            z7 = MNNAVXFMA(s1, w2, z7);
            z8 = MNNAVXFMA(s2, w2, z8);
            z9 = MNNAVXFMA(s0, w3, z9);
            z10 = MNNAVXFMA(s1, w3, z10);
            z11 = MNNAVXFMA(s2, w3, z11);
        }
        if (blockId == 0) {
            TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
            TRANPOSE_SAVE(1, 2, z2, z5, z8, z11);
        } else {
            FMLA_TRANSPOSE_SAVE(0, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(1, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(0, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(1, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(0, 2, z2, z5, z8, z11);
            FMLA_TRANSPOSE_SAVE(1, 2, z2, z5, z8, z11);
        }
    }
}
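
// e = 20 kernel: each channel group accumulates 20 columns of A held as
// 8 + 8 + 4 floats; the third register carries only 4 valid lanes, so just
// five transpose-saves are emitted per group.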
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int8_20(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    int weightBytes = sizeof(int8_t);
    auto bExtraStride = parameter[5] / weightBytes;
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    float ws_tmp[4];
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD8(A + 0 * aStride);
        auto s1 = LOAD8(A + 0 * aStride + 8);
        auto s2 = EXPAND_128(LOAD4(A + 0 * aStride + 16));
        auto ws = _load_int8x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm256_set1_ps(ws_tmp[0]);
        auto w1 = _mm256_set1_ps(ws_tmp[1]);
        auto w2 = _mm256_set1_ps(ws_tmp[2]);
        auto w3 = _mm256_set1_ps(ws_tmp[3]);
        auto z0 = _mm256_mul_ps(s0, w0);
        auto z1 = _mm256_mul_ps(s1, w0);
        auto z2 = _mm256_mul_ps(s2, w0);
        auto z3 = _mm256_mul_ps(s0, w1);
        auto z4 = _mm256_mul_ps(s1, w1);
        auto z5 = _mm256_mul_ps(s2, w1);
        auto z6 = _mm256_mul_ps(s0, w2);
        auto z7 = _mm256_mul_ps(s1, w2);
        auto z8 = _mm256_mul_ps(s2, w2);
        auto z9 = _mm256_mul_ps(s0, w3);
        auto z10 = _mm256_mul_ps(s1, w3);
        auto z11 = _mm256_mul_ps(s2, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD8(A + sy * aStride);
            s1 = LOAD8(A + sy * aStride + 8);
            s2 = EXPAND_128(LOAD4(A + sy * aStride + 16));
            ws = _load_int8x4(weight + sy * 4, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm256_set1_ps(ws_tmp[0]);
            w1 = _mm256_set1_ps(ws_tmp[1]);
            w2 = _mm256_set1_ps(ws_tmp[2]);
            w3 = _mm256_set1_ps(ws_tmp[3]);
            z0 = MNNAVXFMA(s0, w0, z0);
            z1 = MNNAVXFMA(s1, w0, z1);
            z2 = MNNAVXFMA(s2, w0, z2);
            z3 = MNNAVXFMA(s0, w1, z3);
            z4 = MNNAVXFMA(s1, w1, z4);
            z5 = MNNAVXFMA(s2, w1, z5);
            z6 = MNNAVXFMA(s0, w2, z6);
            z7 = MNNAVXFMA(s1, w2, z7);
            z8 = MNNAVXFMA(s2, w2, z8);
            z9 = MNNAVXFMA(s0, w3, z9);
            z10 = MNNAVXFMA(s1, w3, z10);
            z11 = MNNAVXFMA(s2, w3, z11);
        }
        if (0 == blockId) {
            TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(0, 2, z2, z5, z8, z11);
        } else {
            FMLA_TRANSPOSE_SAVE(0, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(1, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(0, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(1, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(0, 2, z2, z5, z8, z11);
        }
    }
}
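
// e = 16 kernel: same scheme with two full 8-float row registers per channel.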
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int8_16(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(int8_t);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    float ws_tmp[4];
    for (int y = 0; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD8(A + 0 * aStride);
        auto s1 = LOAD8(A + 0 * aStride + 8);
        auto ws = _load_int8x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm256_set1_ps(ws_tmp[0]);
        auto w1 = _mm256_set1_ps(ws_tmp[1]);
        auto w2 = _mm256_set1_ps(ws_tmp[2]);
        auto w3 = _mm256_set1_ps(ws_tmp[3]);
        auto z0 = _mm256_mul_ps(s0, w0);
        auto z1 = _mm256_mul_ps(s1, w0);
        auto z3 = _mm256_mul_ps(s0, w1);
        auto z4 = _mm256_mul_ps(s1, w1);
        auto z6 = _mm256_mul_ps(s0, w2);
        auto z7 = _mm256_mul_ps(s1, w2);
        auto z9 = _mm256_mul_ps(s0, w3);
        auto z10 = _mm256_mul_ps(s1, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD8(A + sy * aStride);
            s1 = LOAD8(A + sy * aStride + 8);
            ws = _load_int8x4(weight + sy * 4, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm256_set1_ps(ws_tmp[0]);
            w1 = _mm256_set1_ps(ws_tmp[1]);
            w2 = _mm256_set1_ps(ws_tmp[2]);
            w3 = _mm256_set1_ps(ws_tmp[3]);
            z0 = MNNAVXFMA(s0, w0, z0);
            z1 = MNNAVXFMA(s1, w0, z1);
            z3 = MNNAVXFMA(s0, w1, z3);
            z4 = MNNAVXFMA(s1, w1, z4);
            z6 = MNNAVXFMA(s0, w2, z6);
            z7 = MNNAVXFMA(s1, w2, z7);
            z9 = MNNAVXFMA(s0, w3, z9);
            z10 = MNNAVXFMA(s1, w3, z10);
        }
        if (0 == blockId) {
            TRANPOSE_SAVE(0, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(1, 0, z0, z3, z6, z9);
            TRANPOSE_SAVE(0, 1, z1, z4, z7, z10);
            TRANPOSE_SAVE(1, 1, z1, z4, z7, z10);
        } else {
            FMLA_TRANSPOSE_SAVE(0, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(1, 0, z0, z3, z6, z9);
            FMLA_TRANSPOSE_SAVE(0, 1, z1, z4, z7, z10);
            FMLA_TRANSPOSE_SAVE(1, 1, z1, z4, z7, z10);
        }
    }
}
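
// e = 5 kernel: broadcasts five A scalars per l-step and fuses four channel
// groups per iteration by packing two dequantized int8x4 vectors per __m256.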
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int8_5(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(int8_t);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int8x4
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto sumAvx40 = _mm256_setzero_ps();
        auto sumAvx41 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto S4 = BROAD_LOAD(srcUse + 4);
            auto w0 = _load_int8x4(weight0, alpha0, bias0);
            auto w1 = _load_int8x4(weight1, alpha1, bias1);
            auto w2 = _load_int8x4(weight2, alpha2, bias2);
            auto w3 = _load_int8x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
            sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40);
            sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        if (0 == blockId) {
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst0 + 32, sumAvx40);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
            STORE_8(dst2 + 32, sumAvx41);
        } else {
            auto tmp0 = LOAD8(dst0);
            auto tmp1 = LOAD8(dst0 + 8);
            auto tmp2 = LOAD8(dst0 + 16);
            auto tmp3 = LOAD8(dst0 + 24);
            auto tmp4 = LOAD8(dst0 + 32);
            auto tmp5 = LOAD8(dst2);
            auto tmp6 = LOAD8(dst2 + 8);
            auto tmp7 = LOAD8(dst2 + 16);
            auto tmp8 = LOAD8(dst2 + 24);
            auto tmp9 = LOAD8(dst2 + 32);
            sumAvx00 = _mm256_add_ps(sumAvx00, tmp0);
            sumAvx10 = _mm256_add_ps(sumAvx10, tmp1);
            sumAvx20 = _mm256_add_ps(sumAvx20, tmp2);
            sumAvx30 = _mm256_add_ps(sumAvx30, tmp3);
            sumAvx40 = _mm256_add_ps(sumAvx40, tmp4);
            sumAvx01 = _mm256_add_ps(sumAvx01, tmp5);
            sumAvx11 = _mm256_add_ps(sumAvx11, tmp6);
            sumAvx21 = _mm256_add_ps(sumAvx21, tmp7);
            sumAvx31 = _mm256_add_ps(sumAvx31, tmp8);
            sumAvx41 = _mm256_add_ps(sumAvx41, tmp9);
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst0 + 32, sumAvx40);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
            STORE_8(dst2 + 32, sumAvx41);
        }
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto s3 = BROAD_LOAD_4(A + 0 * aStride + 3);
        auto s4 = BROAD_LOAD_4(A + 0 * aStride + 4);
        auto w0 = _load_int8x4(weight, alpha, bias);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        auto z2 = _mm_mul_ps(s2, w0);
        auto z3 = _mm_mul_ps(s3, w0);
        auto z4 = _mm_mul_ps(s4, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            s3 = BROAD_LOAD_4(A + sy * aStride + 3);
            s4 = BROAD_LOAD_4(A + sy * aStride + 4);
            w0 = _load_int8x4(weight + sy * 4, alpha, bias);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
            z3 = MNNSSEFMA(s3, w0, z3);
            z4 = MNNSSEFMA(s4, w0, z4);
        }
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
            STORE_4(dst + 8 * 3, z3);
            STORE_4(dst + 8 * 4, z4);
        } else {
            auto tmp0 = LOAD4(dst + 8 * 0);
            auto tmp1 = LOAD4(dst + 8 * 1);
            auto tmp2 = LOAD4(dst + 8 * 2);
            auto tmp3 = LOAD4(dst + 8 * 3);
            auto tmp4 = LOAD4(dst + 8 * 4);
            z0 = _mm_add_ps(tmp0, z0);
            z1 = _mm_add_ps(tmp1, z1);
            z2 = _mm_add_ps(tmp2, z2);
            z3 = _mm_add_ps(tmp3, z3);
            z4 = _mm_add_ps(tmp4, z4);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
            STORE_4(dst + 8 * 3, z3);
            STORE_4(dst + 8 * 4, z4);
        }
    }
}
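
// e = 4 kernel: four broadcast columns; leftover channel groups run through a
// 4x4 transpose before the store.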
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int8_4(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(int8_t);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int8x4
        DST_ADDR_UNPACK4(0);
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto S3 = BROAD_LOAD(srcUse + 3);
            auto w0 = _load_int8x4(weight0, alpha0, bias0);
            auto w1 = _load_int8x4(weight1, alpha1, bias1);
            auto w2 = _load_int8x4(weight2, alpha2, bias2);
            auto w3 = _load_int8x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30);
            sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        if (0 == blockId) {
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
        } else {
            auto tmp0 = LOAD8(dst0);
            auto tmp1 = LOAD8(dst0 + 8);
            auto tmp2 = LOAD8(dst0 + 16);
            auto tmp3 = LOAD8(dst0 + 24);
            auto tmp5 = LOAD8(dst2);
            auto tmp6 = LOAD8(dst2 + 8);
            auto tmp7 = LOAD8(dst2 + 16);
            auto tmp8 = LOAD8(dst2 + 24);
            sumAvx00 = _mm256_add_ps(sumAvx00, tmp0);
            sumAvx10 = _mm256_add_ps(sumAvx10, tmp1);
            sumAvx20 = _mm256_add_ps(sumAvx20, tmp2);
            sumAvx30 = _mm256_add_ps(sumAvx30, tmp3);
            sumAvx01 = _mm256_add_ps(sumAvx01, tmp5);
            sumAvx11 = _mm256_add_ps(sumAvx11, tmp6);
            sumAvx21 = _mm256_add_ps(sumAvx21, tmp7);
            sumAvx31 = _mm256_add_ps(sumAvx31, tmp8);
            STORE_8(dst0, sumAvx00);
            STORE_8(dst0 + 8, sumAvx10);
            STORE_8(dst0 + 16, sumAvx20);
            STORE_8(dst0 + 24, sumAvx30);
            STORE_8(dst2, sumAvx01);
            STORE_8(dst2 + 8, sumAvx11);
            STORE_8(dst2 + 16, sumAvx21);
            STORE_8(dst2 + 24, sumAvx31);
        }
    }
    float ws_tmp[4];
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = LOAD4(A + 0 * aStride);
        auto ws = _load_int8x4(weight, alpha, bias);
        _mm_storeu_ps(ws_tmp, ws);
        auto w0 = _mm_set1_ps(ws_tmp[0]);
        auto w1 = _mm_set1_ps(ws_tmp[1]);
        auto w2 = _mm_set1_ps(ws_tmp[2]);
        auto w3 = _mm_set1_ps(ws_tmp[3]);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z3 = _mm_mul_ps(s0, w1);
        auto z6 = _mm_mul_ps(s0, w2);
        auto z9 = _mm_mul_ps(s0, w3);
        for (int sy = 1; sy < l; ++sy) {
            s0 = LOAD4(A + sy * aStride);
            ws = _load_int8x4(weight + sy * 4, alpha, bias);
            _mm_storeu_ps(ws_tmp, ws);
            w0 = _mm_set1_ps(ws_tmp[0]);
            w1 = _mm_set1_ps(ws_tmp[1]);
            w2 = _mm_set1_ps(ws_tmp[2]);
            w3 = _mm_set1_ps(ws_tmp[3]);
            z0 = MNNSSEFMA(s0, w0, z0);
            z3 = MNNSSEFMA(s0, w1, z3);
            z6 = MNNSSEFMA(s0, w2, z6);
            z9 = MNNSSEFMA(s0, w3, z9);
        }
        _MM_TRANSPOSE4_PS(z0, z3, z6, z9);
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z3);
            STORE_4(dst + 8 * 2, z6);
            STORE_4(dst + 8 * 3, z9);
        } else {
            auto tmp0 = LOAD4(dst + 8 * 0);
            auto tmp1 = LOAD4(dst + 8 * 1);
            auto tmp2 = LOAD4(dst + 8 * 2);
            auto tmp3 = LOAD4(dst + 8 * 3);
            z0 = _mm_add_ps(tmp0, z0);
            z3 = _mm_add_ps(tmp1, z3);
            z6 = _mm_add_ps(tmp2, z6);
            z9 = _mm_add_ps(tmp3, z9);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z3);
            STORE_4(dst + 8 * 2, z6);
            STORE_4(dst + 8 * 3, z9);
        }
    }
}
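
// e = 3 kernel: three broadcast columns; AVX accumulators are split back into
// 128-bit halves with _mm256_extractf128_ps before the 4-float stores.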
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int8_3(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(int8_t);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int8x4
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto S2 = BROAD_LOAD(srcUse + 2);
            auto w0 = _load_int8x4(weight0, alpha0, bias0);
            auto w1 = _load_int8x4(weight1, alpha1, bias1);
            auto w2 = _load_int8x4(weight2, alpha2, bias2);
            auto w3 = _load_int8x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20);
            sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        if (0 == blockId) {
            STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
            STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
            STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0));
            STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
            STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
            STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1));
            STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
            STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
            STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0));
            STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
            STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
            STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1));
        } else {
            auto tmp00 = LOAD4(dst0 + 0);
            auto tmp01 = LOAD4(dst0 + 8);
            auto tmp02 = LOAD4(dst0 + 16);
            auto tmp10 = LOAD4(dst1 + 0);
            auto tmp11 = LOAD4(dst1 + 8);
            auto tmp12 = LOAD4(dst1 + 16);
            auto tmp20 = LOAD4(dst2 + 0);
            auto tmp21 = LOAD4(dst2 + 8);
            auto tmp22 = LOAD4(dst2 + 16);
            auto tmp30 = LOAD4(dst3 + 0);
            auto tmp31 = LOAD4(dst3 + 8);
            auto tmp32 = LOAD4(dst3 + 16);
            auto sum_tmp00 = _mm256_extractf128_ps(sumAvx00, 0);
            auto sum_tmp01 = _mm256_extractf128_ps(sumAvx10, 0);
            auto sum_tmp02 = _mm256_extractf128_ps(sumAvx20, 0);
            auto sum_tmp10 = _mm256_extractf128_ps(sumAvx00, 1);
            auto sum_tmp11 = _mm256_extractf128_ps(sumAvx10, 1);
            auto sum_tmp12 = _mm256_extractf128_ps(sumAvx20, 1);
            auto sum_tmp20 = _mm256_extractf128_ps(sumAvx01, 0);
            auto sum_tmp21 = _mm256_extractf128_ps(sumAvx11, 0);
            auto sum_tmp22 = _mm256_extractf128_ps(sumAvx21, 0);
            auto sum_tmp30 = _mm256_extractf128_ps(sumAvx01, 1);
            auto sum_tmp31 = _mm256_extractf128_ps(sumAvx11, 1);
            auto sum_tmp32 = _mm256_extractf128_ps(sumAvx21, 1);
            sum_tmp00 = _mm_add_ps(tmp00, sum_tmp00);
            sum_tmp01 = _mm_add_ps(tmp01, sum_tmp01);
            sum_tmp02 = _mm_add_ps(tmp02, sum_tmp02);
            sum_tmp10 = _mm_add_ps(tmp10, sum_tmp10);
            sum_tmp11 = _mm_add_ps(tmp11, sum_tmp11);
            sum_tmp12 = _mm_add_ps(tmp12, sum_tmp12);
            sum_tmp20 = _mm_add_ps(tmp20, sum_tmp20);
            sum_tmp21 = _mm_add_ps(tmp21, sum_tmp21);
            sum_tmp22 = _mm_add_ps(tmp22, sum_tmp22);
            sum_tmp30 = _mm_add_ps(tmp30, sum_tmp30);
            sum_tmp31 = _mm_add_ps(tmp31, sum_tmp31);
            sum_tmp32 = _mm_add_ps(tmp32, sum_tmp32);
            STORE_4(dst0 + 0, sum_tmp00);
            STORE_4(dst0 + 8, sum_tmp01);
            STORE_4(dst0 + 16, sum_tmp02);
            STORE_4(dst1 + 0, sum_tmp10);
            STORE_4(dst1 + 8, sum_tmp11);
            STORE_4(dst1 + 16, sum_tmp12);
            STORE_4(dst2 + 0, sum_tmp20);
            STORE_4(dst2 + 8, sum_tmp21);
            STORE_4(dst2 + 16, sum_tmp22);
            STORE_4(dst3 + 0, sum_tmp30);
            STORE_4(dst3 + 8, sum_tmp31);
            STORE_4(dst3 + 16, sum_tmp32);
        }
    }
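    // Tail: channel groups left over when hC4 is not a multiple of 4.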
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2);
        auto w0 = _load_int8x4(weight, alpha, bias);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        auto z2 = _mm_mul_ps(s2, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            s2 = BROAD_LOAD_4(A + sy * aStride + 2);
            w0 = _load_int8x4(weight + sy * 4, alpha, bias);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
            z2 = MNNSSEFMA(s2, w0, z2);
        }
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
        } else {
            auto tmp0 = LOAD4(dst + 8 * 0);
            auto tmp1 = LOAD4(dst + 8 * 1);
            auto tmp2 = LOAD4(dst + 8 * 2);
            z0 = _mm_add_ps(tmp0, z0);
            z1 = _mm_add_ps(tmp1, z1);
            z2 = _mm_add_ps(tmp2, z2);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
            STORE_4(dst + 8 * 2, z2);
        }
    }
}
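
// e = 2 kernel: two broadcast columns per l-step.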
template <typename TYPE>
static void _AVX_MNNPackedMatMul_int8_2(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) {
    auto aStride = parameter[0] / sizeof(TYPE);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(int8_t);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    for (int y = 0; y < hC16; ++y) {
        LOAD_WEIGHT_ALPHA_BIAS_int8x4
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        DST_ADDR_UNPACK4(0);
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < l; ++sy) {
            auto S0 = BROAD_LOAD(srcUse + 0);
            auto S1 = BROAD_LOAD(srcUse + 1);
            auto w0 = _load_int8x4(weight0, alpha0, bias0);
            auto w1 = _load_int8x4(weight1, alpha1, bias1);
            auto w2 = _load_int8x4(weight2, alpha2, bias2);
            auto w3 = _load_int8x4(weight3, alpha3, bias3);
            auto W0 = _mm256_set_m128(w1, w0);
            auto W1 = _mm256_set_m128(w3, w2);
            sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00);
            sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01);
            sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11);
            srcUse += aStride;
            weight0 += 4;
            weight1 += 4;
            weight2 += 4;
            weight3 += 4;
        }
        if (0 == blockId) {
            STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0));
            STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0));
            STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1));
            STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1));
            STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0));
            STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0));
            STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1));
            STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1));
        } else {
            auto tmp01 = LOAD4(dst0 + 0);
            auto tmp02 = LOAD4(dst0 + 8);
            auto tmp11 = LOAD4(dst1 + 0);
            auto tmp12 = LOAD4(dst1 + 8);
            auto tmp21 = LOAD4(dst2 + 0);
            auto tmp22 = LOAD4(dst2 + 8);
            auto tmp31 = LOAD4(dst3 + 0);
            auto tmp32 = LOAD4(dst3 + 8);
            auto x_tmp01 = _mm256_extractf128_ps(sumAvx00, 0);
            auto x_tmp02 = _mm256_extractf128_ps(sumAvx10, 0);
            auto x_tmp11 = _mm256_extractf128_ps(sumAvx00, 1);
            auto x_tmp12 = _mm256_extractf128_ps(sumAvx10, 1);
            auto x_tmp21 = _mm256_extractf128_ps(sumAvx01, 0);
            auto x_tmp22 = _mm256_extractf128_ps(sumAvx11, 0);
            auto x_tmp31 = _mm256_extractf128_ps(sumAvx01, 1);
            auto x_tmp32 = _mm256_extractf128_ps(sumAvx11, 1);
            x_tmp01 = _mm_add_ps(tmp01, x_tmp01);
            x_tmp02 = _mm_add_ps(tmp02, x_tmp02);
            x_tmp11 = _mm_add_ps(tmp11, x_tmp11);
            x_tmp12 = _mm_add_ps(tmp12, x_tmp12);
            x_tmp21 = _mm_add_ps(tmp21, x_tmp21);
            x_tmp22 = _mm_add_ps(tmp22, x_tmp22);
            x_tmp31 = _mm_add_ps(tmp31, x_tmp31);
            x_tmp32 = _mm_add_ps(tmp32, x_tmp32);
            STORE_4(dst0 + 0, x_tmp01);
            STORE_4(dst0 + 8, x_tmp02);
            STORE_4(dst1 + 0, x_tmp11);
            STORE_4(dst1 + 8, x_tmp12);
            STORE_4(dst2 + 0, x_tmp21);
            STORE_4(dst2 + 8, x_tmp22);
            STORE_4(dst3 + 0, x_tmp31);
            STORE_4(dst3 + 8, x_tmp32);
        }
    }
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0);
        auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1);
        auto w0 = _load_int8x4(weight, alpha, bias);
        auto z0 = _mm_mul_ps(s0, w0);
        auto z1 = _mm_mul_ps(s1, w0);
        for (int sy = 1; sy < l; ++sy) {
            s0 = BROAD_LOAD_4(A + sy * aStride + 0);
            s1 = BROAD_LOAD_4(A + sy * aStride + 1);
            w0 = _load_int8x4(weight + sy * 4, alpha, bias);
            z0 = MNNSSEFMA(s0, w0, z0);
            z1 = MNNSSEFMA(s1, w0, z1);
        }
        if (0 == blockId) {
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
        } else {
            auto t0 = LOAD4(dst + 8 * 0);
            auto t1 = LOAD4(dst + 8 * 1);
            z0 = _mm_add_ps(z0, t0);
            z1 = _mm_add_ps(z1, t1);
            STORE_4(dst + 8 * 0, z0);
            STORE_4(dst + 8 * 1, z1);
        }
    }
}
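
// Remainder dispatcher for e < 24: peels 20 and 16, then blocks of 5, then
// the 4/3/2 kernels; a single remaining column (eSize == 1) falls through to
// the per-column loop below.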
template <typename TYPE>
static void _AVX_MNNPackednMatMulRemainCommon_int8(TYPE* C, const TYPE* A, const TYPE* fB, size_t eSize, const size_t* parameter, const float* k, const float* b) {
    auto B = reinterpret_cast<const int8_t*>(fB);
    auto h = parameter[2];
    auto l = parameter[1];
    auto cStride = parameter[3] / sizeof(TYPE);
    auto bExtraStride = parameter[5] / sizeof(int8_t);
    auto bStride = bExtraStride + 4 * l;
    auto blockId = parameter[6];
    auto hC4 = UP_DIV(h, 4);
    auto es = eSize;
    auto oC = C;
    auto aStride = parameter[0] / sizeof(TYPE);
    if (eSize >= 20) {
        _AVX_MNNPackedMatMul_int8_20<TYPE>(C, A, B, parameter, k, b);
        eSize -= 20;
        C += 20 * 8;
        A += 20;
    }
    if (eSize >= 16) {
        _AVX_MNNPackedMatMul_int8_16<TYPE>(C, A, B, parameter, k, b);
        eSize -= 16;
        C += 16 * 8;
        A += 16;
    }
    while (eSize >= 5) {
        _AVX_MNNPackedMatMul_int8_5<TYPE>(C, A, B, parameter, k, b);
        eSize -= 5;
        C += 5 * 8;
        A += 5;
    }
    if (eSize == 4) {
        _AVX_MNNPackedMatMul_int8_4<TYPE>(C, A, B, parameter, k, b);
        return;
    }
    if (eSize == 3) {
        _AVX_MNNPackedMatMul_int8_3<TYPE>(C, A, B, parameter, k, b);
        return;
    }
    if (eSize == 2) {
        _AVX_MNNPackedMatMul_int8_2<TYPE>(C, A, B, parameter, k, b);
        return;
    }
    if (eSize == 0) {
        return;
    }
    int lC4 = l / 4;
    int lR = lC4 * 4;
    const int hC4Unit = 4;
    int hC16 = hC4 / hC4Unit;
    int hR = hC16 * hC4Unit;
    auto src = A;
    int x = 0;
    for (int y = 0; y < hC16; ++y) {
        auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8;
        auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4;
        auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8;
        auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4;
        LOAD_WEIGHT_ALPHA_BIAS_int8x4
        LOAD_ALPHA_BIAS_DOUBLE
        auto sumAvx00 = _mm256_setzero_ps();
        auto sumAvx01 = _mm256_setzero_ps();
        auto sumAvx10 = _mm256_setzero_ps();
        auto sumAvx11 = _mm256_setzero_ps();
        auto sumAvx20 = _mm256_setzero_ps();
        auto sumAvx21 = _mm256_setzero_ps();
        auto sumAvx30 = _mm256_setzero_ps();
        auto sumAvx31 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W00 = _load_int8x8(weight0 + 16 * sy + 0, alpha0_2, bias0_2);
            auto W01 = _load_int8x8(weight0 + 16 * sy + 8, alpha0_2, bias0_2);
            auto W10 = _load_int8x8(weight1 + 16 * sy + 0, alpha1_2, bias1_2);
            auto W11 = _load_int8x8(weight1 + 16 * sy + 8, alpha1_2, bias1_2);
            auto W20 = _load_int8x8(weight2 + 16 * sy + 0, alpha2_2, bias2_2);
            auto W21 = _load_int8x8(weight2 + 16 * sy + 8, alpha2_2, bias2_2);
            auto W30 = _load_int8x8(weight3 + 16 * sy + 0, alpha3_2, bias3_2);
            auto W31 = _load_int8x8(weight3 + 16 * sy + 8, alpha3_2, bias3_2);
            sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00);
            sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01);
            sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10);
            sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11);
            sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20);
            sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21);
            sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30);
            sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31);
            srcUse += 4 * aStride;
        }
        sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01);
        sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11);
        sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21);
        sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31);
        auto sum00 = _mm256_extractf128_ps(sumAvx00, 0);
        auto sum01 = _mm256_extractf128_ps(sumAvx00, 1);
        auto sum0 = _mm_add_ps(sum00, sum01);
        auto sum10 = _mm256_extractf128_ps(sumAvx10, 0);
        auto sum11 = _mm256_extractf128_ps(sumAvx10, 1);
        auto sum1 = _mm_add_ps(sum10, sum11);
        auto sum20 = _mm256_extractf128_ps(sumAvx20, 0);
        auto sum21 = _mm256_extractf128_ps(sumAvx20, 1);
        auto sum2 = _mm_add_ps(sum20, sum21);
        auto sum30 = _mm256_extractf128_ps(sumAvx30, 0);
        auto sum31 = _mm256_extractf128_ps(sumAvx30, 1);
        auto sum3 = _mm_add_ps(sum30, sum31);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w0 = _load_int8x4(weight0 + 4 * sy, alpha0, bias0);
            auto w1 = _load_int8x4(weight1 + 4 * sy, alpha1, bias1);
            auto w2 = _load_int8x4(weight2 + 4 * sy, alpha2, bias2);
            auto w3 = _load_int8x4(weight3 + 4 * sy, alpha3, bias3);
            sum0 = MNNSSEFMA(s, w0, sum0);
            sum1 = MNNSSEFMA(s, w1, sum1);
            sum2 = MNNSSEFMA(s, w2, sum2);
            sum3 = MNNSSEFMA(s, w3, sum3);
            srcUse += aStride;
        }
        if (blockId == 0) {
            STORE_4(dst0, sum0);
            STORE_4(dst1, sum1);
            STORE_4(dst2, sum2);
            STORE_4(dst3, sum3);
        } else {
            auto tmp_0 = LOAD4(dst0);
            auto tmp_1 = LOAD4(dst1);
            auto tmp_2 = LOAD4(dst2);
            auto tmp_3 = LOAD4(dst3);
            sum0 = _mm_add_ps(tmp_0, sum0);
            sum1 = _mm_add_ps(tmp_1, sum1);
            sum2 = _mm_add_ps(tmp_2, sum2);
            sum3 = _mm_add_ps(tmp_3, sum3);
            STORE_4(dst0, sum0);
            STORE_4(dst1, sum1);
            STORE_4(dst2, sum2);
            STORE_4(dst3, sum3);
        }
    }
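    // Leftover channel groups: unroll four l-steps with int8x8 dequantized
    // weights, then finish the remaining l with a scalar tail.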
    for (int y = hR; y < hC4; ++y) {
        auto weight = B + y * bStride;
        auto dst = C + (y / 2) * cStride + x * 8 + 4 * (y % 2);
        auto alpha = _mm_loadu_ps(k + y * 4);
        auto bias = _mm_loadu_ps(b + y * 4);
        auto alpha_2 = _mm256_set_m128(alpha, alpha);
        auto bias_2 = _mm256_set_m128(bias, bias);
        auto sumAvx0 = _mm256_setzero_ps();
        auto sumAvx1 = _mm256_setzero_ps();
        auto srcUse = src;
        for (int sy = 0; sy < lC4; ++sy) {
            auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride));
            auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride));
            auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1));
            auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride));
            auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride));
            auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1));
            auto W0 = _load_int8x8(weight + 16 * sy + 0, alpha_2, bias_2);
            auto W1 = _load_int8x8(weight + 16 * sy + 8, alpha_2, bias_2);
            sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0);
            sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1);
            srcUse += 4 * aStride;
        }
        sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1);
        auto sum0 = _mm256_extractf128_ps(sumAvx0, 0);
        auto sum1 = _mm256_extractf128_ps(sumAvx0, 1);
        auto sum = _mm_add_ps(sum0, sum1);
        for (int sy = lR; sy < l; ++sy) {
            auto s = BROAD_LOAD_4(srcUse);
            auto w = _load_int8x4(weight + sy * 4, alpha, bias);
            sum = MNNSSEFMA(s, w, sum);
            srcUse += aStride;
        }
        if (blockId == 0) {
            STORE_4(dst, sum);
        } else {
            auto tmp_0 = LOAD4(dst);
            sum = _mm_add_ps(tmp_0, sum);
            STORE_4(dst, sum);
        }
    }
}
#endif