in source/backend/cpu/x86_x64/sse/GemmSSE.cpp [138:424]
// Loads four floats from srcPtr, multiplies them by scale[x], adds bias[x] when a bias
// array is supplied, rounds to the nearest integer and returns the four int32 results
// shifted by +128.
static inline __m128i SSEDynamicQuantToOffsetInt32(const float* srcPtr, const float* scale, const float* bias, size_t x) {
    auto value = _mm_mul_ps(_mm_loadu_ps(srcPtr), _mm_set1_ps(scale[x]));
    if (bias != nullptr) {
        value = _mm_add_ps(value, _mm_set1_ps(bias[x]));
    }
    // Rounding control 0 selects round-to-nearest (_MM_FROUND_TO_NEAREST_INT).
    value = _mm_round_ps(value, 0);
    return _mm_add_epi32(_mm_cvtps_epi32(value), _mm_set1_epi32(128));
}

// Narrows four int32x4 vectors into sixteen bytes, keeping lane order a,b,c,d:
// signed saturation to int16 first, then unsigned saturation to [0, 255].
static inline __m128i SSEDynamicQuantNarrowToU8(__m128i a, __m128i b, __m128i c, __m128i d) {
    return _mm_packus_epi16(_mm_packs_epi32(a, b), _mm_packs_epi32(c, d));
}

// Quantizes packed float data to bytes: out = saturate_u8(round(src * scale[x] + bias[x]) + 128).
//
// src            : src_depth_quad blocks, each holding realSize groups of `pack` floats.
// dst            : receives `pack` bytes per group, written contiguously in the same order.
// scale          : one multiplier per group index x (realSize entries), reused for every block.
// src_depth_quad : number of blocks.
// realSize       : number of groups per block.
// pack           : floats per group; only 4 and 16 are handled, anything else logs an error
//                  and leaves dst untouched.
// bias           : optional (may be nullptr) per-group addend, indexed like scale.
void _SSE_MNNDynamicQuant(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack, const float* bias) {
    if (pack != 4 && pack != 16) {
        MNN_ERROR("dynamic quant error: x86_x64 sse don't suppport pack=%d yet\n", pack);
        return;
    }

    const size_t srcStep = realSize * pack;
    int8_t* dstPtr = dst;

    if (pack == 4) { // core->pack
        for (size_t i = 0; i < src_depth_quad; ++i) {
            const float* srcPtr = src + i * srcStep;
            size_t x = 0;
            // Four groups at a time: 4 x 4 floats -> one 16-byte store.
            for (; x + 4 <= realSize; x += 4) {
                const auto r0 = SSEDynamicQuantToOffsetInt32(srcPtr, scale, bias, x);
                const auto r1 = SSEDynamicQuantToOffsetInt32(srcPtr + 4, scale, bias, x + 1);
                const auto r2 = SSEDynamicQuantToOffsetInt32(srcPtr + 8, scale, bias, x + 2);
                const auto r3 = SSEDynamicQuantToOffsetInt32(srcPtr + 12, scale, bias, x + 3);
                _mm_storeu_si128(reinterpret_cast<__m128i*>(dstPtr), SSEDynamicQuantNarrowToU8(r0, r1, r2, r3));
                srcPtr += 16;
                dstPtr += 16;
            }
            // Remaining groups: only four output bytes are valid, so spill the vector and
            // copy them byte-wise instead of writing a type-punned, possibly misaligned int32.
            for (; x < realSize; ++x) {
                const auto r0 = SSEDynamicQuantToOffsetInt32(srcPtr, scale, bias, x);
                int8_t lanes[16];
                _mm_storeu_si128(reinterpret_cast<__m128i*>(lanes), SSEDynamicQuantNarrowToU8(r0, r0, r0, r0));
                for (int k = 0; k < 4; ++k) {
                    dstPtr[k] = lanes[k];
                }
                srcPtr += 4;
                dstPtr += 4;
            }
        }
        return;
    }

    // pack == 16: every group is sixteen floats sharing one scale/bias and yields one 16-byte store.
    for (size_t i = 0; i < src_depth_quad; ++i) {
        const float* srcPtr = src + i * srcStep;
        for (size_t x = 0; x < realSize; ++x) {
            const auto r0 = SSEDynamicQuantToOffsetInt32(srcPtr, scale, bias, x);
            const auto r1 = SSEDynamicQuantToOffsetInt32(srcPtr + 4, scale, bias, x);
            const auto r2 = SSEDynamicQuantToOffsetInt32(srcPtr + 8, scale, bias, x);
            const auto r3 = SSEDynamicQuantToOffsetInt32(srcPtr + 12, scale, bias, x);
            _mm_storeu_si128(reinterpret_cast<__m128i*>(dstPtr), SSEDynamicQuantNarrowToU8(r0, r1, r2, r3));
            srcPtr += 16;
            dstPtr += 16;
        }
    }
}