in source/backend/cpu/x86_x64/avx512/Gemm31_16.h [12:241]
void _AVX512_MNNPackedMatMulO16FullLoadKernel(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias) {
#define REDUCE_MUL_ADD(ick) \
zmm0 = _mm512_loadu_ps(filterICPtr + ick * bStride); \
if (InputTile > 8) \
_mm_prefetch(filterICPtr + ick * bStride + AVX512_PACK_C_UNIT * AVX512_PACK_C_UNIT, _MM_HINT_T0); \
if (InputTile > 0) \
zmm1 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 0 * AVX512_PACK_C_UNIT]), zmm0, zmm1); \
if (InputTile > 1) \
zmm2 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 1 * AVX512_PACK_C_UNIT]), zmm0, zmm2); \
if (InputTile > 2) \
zmm3 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 2 * AVX512_PACK_C_UNIT]), zmm0, zmm3); \
if (InputTile > 3) \
zmm4 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 3 * AVX512_PACK_C_UNIT]), zmm0, zmm4); \
if (InputTile > 4) \
zmm5 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 4 * AVX512_PACK_C_UNIT]), zmm0, zmm5); \
if (InputTile > 5) \
zmm6 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 5 * AVX512_PACK_C_UNIT]), zmm0, zmm6); \
if (InputTile > 6) \
zmm7 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 6 * AVX512_PACK_C_UNIT]), zmm0, zmm7); \
if (InputTile > 7) \
zmm8 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 7 * AVX512_PACK_C_UNIT]), zmm0, zmm8); \
if (InputTile > 8) \
zmm9 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 8 * AVX512_PACK_C_UNIT]), zmm0, zmm9); \
if (InputTile > 9) \
zmm10 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 9 * AVX512_PACK_C_UNIT]), zmm0, zmm10); \
if (InputTile > 10) \
zmm11 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 10 * AVX512_PACK_C_UNIT]), zmm0, zmm11); \
if (InputTile > 11) \
zmm12 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 11 * AVX512_PACK_C_UNIT]), zmm0, zmm12); \
if (InputTile > 12) \
zmm13 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 12 * AVX512_PACK_C_UNIT]), zmm0, zmm13); \
if (InputTile > 13) \
zmm14 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 13 * AVX512_PACK_C_UNIT]), zmm0, zmm14); \
if (InputTile > 14) \
zmm15 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 14 * AVX512_PACK_C_UNIT]), zmm0, zmm15); \
if (InputTile > 15) \
zmm16 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 15 * AVX512_PACK_C_UNIT]), zmm0, zmm16); \
if (InputTile > 16) \
zmm17 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 16 * AVX512_PACK_C_UNIT]), zmm0, zmm17); \
if (InputTile > 17) \
zmm18 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 17 * AVX512_PACK_C_UNIT]), zmm0, zmm18); \
if (InputTile > 18) \
zmm19 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 18 * AVX512_PACK_C_UNIT]), zmm0, zmm19); \
if (InputTile > 19) \
zmm20 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 19 * AVX512_PACK_C_UNIT]), zmm0, zmm20); \
if (InputTile > 20) \
zmm21 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 20 * AVX512_PACK_C_UNIT]), zmm0, zmm21); \
if (InputTile > 21) \
zmm22 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 21 * AVX512_PACK_C_UNIT]), zmm0, zmm22); \
if (InputTile > 22) \
zmm23 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 22 * AVX512_PACK_C_UNIT]), zmm0, zmm23); \
if (InputTile > 23) \
zmm24 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 23 * AVX512_PACK_C_UNIT]), zmm0, zmm24); \
if (InputTile > 24) \
zmm25 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 24 * AVX512_PACK_C_UNIT]), zmm0, zmm25); \
if (InputTile > 25) \
zmm26 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 25 * AVX512_PACK_C_UNIT]), zmm0, zmm26); \
if (InputTile > 26) \
zmm27 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 26 * AVX512_PACK_C_UNIT]), zmm0, zmm27); \
if (InputTile > 27) \
zmm28 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 27 * AVX512_PACK_C_UNIT]), zmm0, zmm28); \
if (InputTile > 28) \
zmm29 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 28 * AVX512_PACK_C_UNIT]), zmm0, zmm29); \
if (InputTile > 29) \
zmm30 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 29 * AVX512_PACK_C_UNIT]), zmm0, zmm30); \
if (InputTile > 30) \
zmm31 = _mm512_fmadd_ps(_mm512_set1_ps(inputICPtr[(ick) + 30 * AVX512_PACK_C_UNIT]), zmm0, zmm31);
auto aStride = parameter[0] / sizeof(float);
auto l = parameter[1];
auto h = parameter[2];
auto cStride = parameter[3] / sizeof(float);
auto bStride = parameter[5] / sizeof(float);
int aTotal = parameter[6];
auto icTail = l % AVX512_PACK_C_UNIT;
auto icPack = l - icTail;
auto inputTilePtr = A;
auto destPtr = C;
__m512 zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15;
__m512 zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
for(; aTotal > 0; aTotal -= InputTile) {
auto inputPtr = inputTilePtr;
auto filterPtr = B;
auto biasPtr = bias;
if (biasPtr) {
if (InputTile > 0 ) zmm1 = _mm512_loadu_ps(biasPtr);
if (InputTile > 1 ) zmm2 = zmm1;
if (InputTile > 2 ) zmm3 = zmm1;
if (InputTile > 3 ) zmm4 = zmm1;
if (InputTile > 4 ) zmm5 = zmm1;
if (InputTile > 5 ) zmm6 = zmm1;
if (InputTile > 6 ) zmm7 = zmm1;
if (InputTile > 7 ) zmm8 = zmm1;
if (InputTile > 8 ) zmm9 = zmm1;
if (InputTile > 9 ) zmm10 = zmm1;
if (InputTile > 10) zmm11 = zmm1;
if (InputTile > 11) zmm12 = zmm1;
if (InputTile > 12) zmm13 = zmm1;
if (InputTile > 13) zmm14 = zmm1;
if (InputTile > 14) zmm15 = zmm1;
if (InputTile > 15) zmm16 = zmm1;
if (InputTile > 16) zmm17 = zmm1;
if (InputTile > 17) zmm18 = zmm1;
if (InputTile > 18) zmm19 = zmm1;
if (InputTile > 19) zmm20 = zmm1;
if (InputTile > 20) zmm21 = zmm1;
if (InputTile > 21) zmm22 = zmm1;
if (InputTile > 22) zmm23 = zmm1;
if (InputTile > 23) zmm24 = zmm1;
if (InputTile > 24) zmm25 = zmm1;
if (InputTile > 25) zmm26 = zmm1;
if (InputTile > 26) zmm27 = zmm1;
if (InputTile > 27) zmm28 = zmm1;
if (InputTile > 28) zmm29 = zmm1;
if (InputTile > 29) zmm30 = zmm1;
if (InputTile > 30) zmm31 = zmm1;
} else {
if (InputTile > 0 ) zmm1 = _mm512_setzero_ps();
if (InputTile > 1 ) zmm2 = _mm512_setzero_ps();
if (InputTile > 2 ) zmm3 = _mm512_setzero_ps();
if (InputTile > 3 ) zmm4 = _mm512_setzero_ps();
if (InputTile > 4 ) zmm5 = _mm512_setzero_ps();
if (InputTile > 5 ) zmm6 = _mm512_setzero_ps();
if (InputTile > 6 ) zmm7 = _mm512_setzero_ps();
if (InputTile > 7 ) zmm8 = _mm512_setzero_ps();
if (InputTile > 8 ) zmm9 = _mm512_setzero_ps();
if (InputTile > 9 ) zmm10 = _mm512_setzero_ps();
if (InputTile > 10) zmm11 = _mm512_setzero_ps();
if (InputTile > 11) zmm12 = _mm512_setzero_ps();
if (InputTile > 12) zmm13 = _mm512_setzero_ps();
if (InputTile > 13) zmm14 = _mm512_setzero_ps();
if (InputTile > 14) zmm15 = _mm512_setzero_ps();
if (InputTile > 15) zmm16 = _mm512_setzero_ps();
if (InputTile > 16) zmm17 = _mm512_setzero_ps();
if (InputTile > 17) zmm18 = _mm512_setzero_ps();
if (InputTile > 18) zmm19 = _mm512_setzero_ps();
if (InputTile > 19) zmm20 = _mm512_setzero_ps();
if (InputTile > 20) zmm21 = _mm512_setzero_ps();
if (InputTile > 21) zmm22 = _mm512_setzero_ps();
if (InputTile > 22) zmm23 = _mm512_setzero_ps();
if (InputTile > 23) zmm24 = _mm512_setzero_ps();
if (InputTile > 24) zmm25 = _mm512_setzero_ps();
if (InputTile > 25) zmm26 = _mm512_setzero_ps();
if (InputTile > 26) zmm27 = _mm512_setzero_ps();
if (InputTile > 27) zmm28 = _mm512_setzero_ps();
if (InputTile > 28) zmm29 = _mm512_setzero_ps();
if (InputTile > 29) zmm30 = _mm512_setzero_ps();
if (InputTile > 30) zmm31 = _mm512_setzero_ps();
}
for(int il = 0; il < icPack; il += AVX512_PACK_C_UNIT) {
auto inputICPtr = inputPtr;
auto filterICPtr = filterPtr;
// REDUCE_MUL_ADD(0 );
// REDUCE_MUL_ADD(1 );
// REDUCE_MUL_ADD(2 );
// REDUCE_MUL_ADD(3 );
// REDUCE_MUL_ADD(4 );
// REDUCE_MUL_ADD(5 );
// REDUCE_MUL_ADD(6 );
// REDUCE_MUL_ADD(7 );
// REDUCE_MUL_ADD(8 );
// REDUCE_MUL_ADD(9 );
// REDUCE_MUL_ADD(10);
// REDUCE_MUL_ADD(11);
// REDUCE_MUL_ADD(12);
// REDUCE_MUL_ADD(13);
// REDUCE_MUL_ADD(14);
// REDUCE_MUL_ADD(15);
for (int ick = 0; ick < AVX512_PACK_C_UNIT; ++ick) {
REDUCE_MUL_ADD(ick);
}
inputPtr += InputTile * AVX512_PACK_C_UNIT;
filterPtr += bStride * AVX512_PACK_C_UNIT;
}
auto inputICPtr = inputPtr;
auto filterICPtr = filterPtr;
float out[16] = {0};
for(int ick = 0; ick < icTail; ++ick) {
REDUCE_MUL_ADD(ick);
}
// write
// oc < 16;
if (InputTile > 0 ) _mm512_storeu_ps(destPtr + 0 * AVX512_PACK_C_UNIT, zmm1 );
if (InputTile > 1 ) _mm512_storeu_ps(destPtr + 1 * AVX512_PACK_C_UNIT, zmm2 );
if (InputTile > 2 ) _mm512_storeu_ps(destPtr + 2 * AVX512_PACK_C_UNIT, zmm3 );
if (InputTile > 3 ) _mm512_storeu_ps(destPtr + 3 * AVX512_PACK_C_UNIT, zmm4 );
if (InputTile > 4 ) _mm512_storeu_ps(destPtr + 4 * AVX512_PACK_C_UNIT, zmm5 );
if (InputTile > 5 ) _mm512_storeu_ps(destPtr + 5 * AVX512_PACK_C_UNIT, zmm6 );
if (InputTile > 6 ) _mm512_storeu_ps(destPtr + 6 * AVX512_PACK_C_UNIT, zmm7 );
if (InputTile > 7 ) _mm512_storeu_ps(destPtr + 7 * AVX512_PACK_C_UNIT, zmm8 );
if (InputTile > 8 ) _mm512_storeu_ps(destPtr + 8 * AVX512_PACK_C_UNIT, zmm9 );
if (InputTile > 9 ) _mm512_storeu_ps(destPtr + 9 * AVX512_PACK_C_UNIT, zmm10);
if (InputTile > 10) _mm512_storeu_ps(destPtr + 10 * AVX512_PACK_C_UNIT, zmm11);
if (InputTile > 11) _mm512_storeu_ps(destPtr + 11 * AVX512_PACK_C_UNIT, zmm12);
if (InputTile > 12) _mm512_storeu_ps(destPtr + 12 * AVX512_PACK_C_UNIT, zmm13);
if (InputTile > 13) _mm512_storeu_ps(destPtr + 13 * AVX512_PACK_C_UNIT, zmm14);
if (InputTile > 14) _mm512_storeu_ps(destPtr + 14 * AVX512_PACK_C_UNIT, zmm15);
if (InputTile > 15) _mm512_storeu_ps(destPtr + 15 * AVX512_PACK_C_UNIT, zmm16);
if (InputTile > 16) _mm512_storeu_ps(destPtr + 16 * AVX512_PACK_C_UNIT, zmm17);
if (InputTile > 17) _mm512_storeu_ps(destPtr + 17 * AVX512_PACK_C_UNIT, zmm18);
if (InputTile > 18) _mm512_storeu_ps(destPtr + 18 * AVX512_PACK_C_UNIT, zmm19);
if (InputTile > 19) _mm512_storeu_ps(destPtr + 19 * AVX512_PACK_C_UNIT, zmm20);
if (InputTile > 20) _mm512_storeu_ps(destPtr + 20 * AVX512_PACK_C_UNIT, zmm21);
if (InputTile > 21) _mm512_storeu_ps(destPtr + 21 * AVX512_PACK_C_UNIT, zmm22);
if (InputTile > 22) _mm512_storeu_ps(destPtr + 22 * AVX512_PACK_C_UNIT, zmm23);
if (InputTile > 23) _mm512_storeu_ps(destPtr + 23 * AVX512_PACK_C_UNIT, zmm24);
if (InputTile > 24) _mm512_storeu_ps(destPtr + 24 * AVX512_PACK_C_UNIT, zmm25);
if (InputTile > 25) _mm512_storeu_ps(destPtr + 25 * AVX512_PACK_C_UNIT, zmm26);
if (InputTile > 26) _mm512_storeu_ps(destPtr + 26 * AVX512_PACK_C_UNIT, zmm27);
if (InputTile > 27) _mm512_storeu_ps(destPtr + 27 * AVX512_PACK_C_UNIT, zmm28);
if (InputTile > 28) _mm512_storeu_ps(destPtr + 28 * AVX512_PACK_C_UNIT, zmm29);
if (InputTile > 29) _mm512_storeu_ps(destPtr + 29 * AVX512_PACK_C_UNIT, zmm30);
if (InputTile > 30) _mm512_storeu_ps(destPtr + 30 * AVX512_PACK_C_UNIT, zmm31);
// oc < 32
auto writeDestPtr = destPtr + cStride;
inputTilePtr += aStride;
destPtr += InputTile * AVX512_PACK_C_UNIT;
}
#undef REDUCE_MUL_ADD
}