in source/backend/cpu/x86_x64/sse/GemmInt8.cpp [349:634]
void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                          size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
    MNN_ASSERT(post->useInt8 == 0);
    const auto dst_step_tmp = dst_step / sizeof(int8_t);
    __m128i zero = _mm_set1_epi32(0);
    __m128 minValue = _mm_set1_ps(post->minValue);
    __m128 maxValue = _mm_set1_ps(post->maxValue);
    __m128 fp32min, fp32max;
    if (post->fp32minmax) {
        fp32min = _mm_set1_ps((post->fp32minmax)[0]);
        fp32max = _mm_set1_ps((post->fp32minmax)[1]);
    }
    const float* biasPtr = nullptr;
    if (post->biasFloat) {
        biasPtr = post->biasFloat;
    }
    auto accumbuff = post->accumBuffer;
    auto blockNum = post->blockNum;
    // Weights are int4, packed two per byte, so each weight occupies half a byte:
    // hence the 0.5 factor (the products are even, so the double math is exact).
    // Per (dz, block) chunk the packed weights are followed by GEMM_INT8_UNIT float
    // scales and GEMM_INT8_UNIT float weight biases: the "+ 4 * 2 * GEMM_INT8_UNIT"
    // bytes, with 4 == sizeof(float).
    int weight_step_Z = 0.5 * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) + 4 * 2 * GEMM_INT8_UNIT;
    int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
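    // Sanity check of the layout arithmetic (assuming the usual SSE packing of
    // GEMM_INT8_UNIT == 4 and GEMM_INT8_SRC_UNIT == 16, as the "16x4" in the
    // function name suggests): with src_depth_quad == 8,
    //     weight_step_Y = 0.5 * (4 * 16)      = 32 bytes per depth step,
    //     weight_step_Z = 8 * 32 + 4 * 2 * 4  = 288 bytes per (dz, block) chunk.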
    auto oneValue = _mm_set1_epi16(1); // used only by the commented-out maddubs variant below
    auto offset = _mm_set1_epi32(128);
    auto neg128f = _mm_set1_ps(-128.f);
    auto srcKernelSumPtr = post->srcKernelSum;
    __m128 kernelSum0 = _mm_setzero_ps();
    __m128 kernelSum1 = _mm_setzero_ps();
    __m128 kernelSum2 = _mm_setzero_ps();
    __m128 kernelSum3 = _mm_setzero_ps();
    __m128 extrabias0 = _mm_setzero_ps();
    __m128 extrabias1 = _mm_setzero_ps();
    __m128 extrabias2 = _mm_setzero_ps();
    __m128 extrabias3 = _mm_setzero_ps();
    const auto mask = _mm_set1_epi8(0xf); // low-nibble mask for unpacking the int4 weights
    if (GEMM_INT8_DST_XUNIT == realDst) {
        kernelSum0 = _mm_load_ps1(post->srcKernelSum);
        kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1);
        kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2);
        kernelSum3 = _mm_load_ps1(post->srcKernelSum + 3);
    } else {
        kernelSum0 = _mm_load_ps1(post->srcKernelSum);
        if (realDst > 1) {
            kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1);
        }
        if (realDst > 2) {
            kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2);
        }
    }
    auto f128 = _mm_set1_ps(128.f);
    __m128 extrascale0 = _mm_setzero_ps();
    __m128 extrascale1 = _mm_setzero_ps();
    __m128 extrascale2 = _mm_setzero_ps();
    __m128 extrascale3 = _mm_setzero_ps();
    if (post->inputScale) {
        if (GEMM_INT8_DST_XUNIT == realDst) {
            extrascale0 = _mm_load_ps1(post->inputScale);
            extrascale1 = _mm_load_ps1(post->inputScale + 1);
            extrascale2 = _mm_load_ps1(post->inputScale + 2);
            extrascale3 = _mm_load_ps1(post->inputScale + 3);
        } else {
            extrascale0 = _mm_load_ps1(post->inputScale);
            if (realDst > 1) {
                extrascale1 = _mm_load_ps1(post->inputScale + 1);
            }
            if (realDst > 2) {
                extrascale2 = _mm_load_ps1(post->inputScale + 2);
            }
        }
    }
    // bias0..bias3 are only written and read under the weightKernelSum guard in
    // the main loop, whose condition guarantees that one of the two assigning
    // branches has run before they are used.
    __m128 bias0, bias1, bias2, bias3;
    if (post->inputBias) {
        if (GEMM_INT8_DST_XUNIT == realDst) {
            extrabias0 = _mm_load_ps1(post->inputBias);
            extrabias1 = _mm_load_ps1(post->inputBias + 1);
            extrabias2 = _mm_load_ps1(post->inputBias + 2);
            extrabias3 = _mm_load_ps1(post->inputBias + 3);
        } else {
            extrabias0 = _mm_load_ps1(post->inputBias);
            if (realDst > 1) {
                extrabias1 = _mm_load_ps1(post->inputBias + 1);
            }
            if (realDst > 2) {
                extrabias2 = _mm_load_ps1(post->inputBias + 2);
            }
        }
    }
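    // kernelSum0..3 are re-read per block inside the main loop below, and
    // extrascale/extrabias are re-read per block when post->inputBias is set;
    // otherwise the preloads above are used for every block.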
    for (int dz = 0; dz < dst_depth_quad; ++dz) {
        auto dst_x = dst + dz * dst_step_tmp;
        auto accum_x = accumbuff;
        for (int bk = 0; bk < blockNum; ++bk) {
            // this block's weight, scale and weight-bias
            const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
            const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
            const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT;
            // this block's input
            const auto src_x = src + bk * src_depth_quad * GEMM_INT8_SRC_UNIT * realDst;
            // d*/e*/D*/E* accumulate int32 dot products: one register group per
            // source pixel (0..3), four registers for the four weight rows.
            __m128i d0 = _mm_set1_epi32(0);
            __m128i d1 = _mm_set1_epi32(0);
            __m128i d2 = _mm_set1_epi32(0);
            __m128i d3 = _mm_set1_epi32(0);
            __m128i e0 = _mm_set1_epi32(0);
            __m128i e1 = _mm_set1_epi32(0);
            __m128i e2 = _mm_set1_epi32(0);
            __m128i e3 = _mm_set1_epi32(0);
            __m128i D0 = _mm_set1_epi32(0);
            __m128i D1 = _mm_set1_epi32(0);
            __m128i D2 = _mm_set1_epi32(0);
            __m128i D3 = _mm_set1_epi32(0);
            __m128i E0 = _mm_set1_epi32(0);
            __m128i E1 = _mm_set1_epi32(0);
            __m128i E2 = _mm_set1_epi32(0);
            __m128i E3 = _mm_set1_epi32(0);
            for (int sz = 0; sz < src_depth_quad; ++sz) {
                const auto weight_sz = weight_dz + weight_step_Y * sz;
                const auto src_z = src_x + sz * realDst * GEMM_INT8_SRC_UNIT;
                LOAD_INT4_TO_INT8;
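                // LOAD_INT4_TO_INT8 (defined earlier in this file) unpacks the packed
                // int4 weights at weight_sz into int8 vectors w0..w3, one
                // GEMM_INT8_SRC_UNIT-wide row per output channel, using `mask` (0xf)
                // to split each byte into its two nibbles. (Described from how the
                // macro is used below; the exact expansion lives with its definition.)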
                auto s0 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 0));
                auto s1 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 1));
                auto s2 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 2));
                auto s3 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 3));
                // Alternative implementation kept for reference (pmaddubsw saturates
                // the pairwise int16 sums, which the unpack-based version avoids):
                //   #define COMPUTE(i, j)
                //       auto d##i##j = _mm_maddubs_epi16(s##i, w##j);
                //       d##i##j = _mm_madd_epi16(d##i##j, oneValue);
#define COMPUTE(i, j)\
auto W##i##j##0 = _mm_srai_epi16(_mm_unpacklo_epi8(zero, w##j), 8);\
auto W##i##j##1 = _mm_srai_epi16(_mm_unpackhi_epi8(zero, w##j), 8);\
auto S##i##j##0 = _mm_unpacklo_epi8(s##i, zero);\
auto S##i##j##1 = _mm_unpackhi_epi8(s##i, zero);\
auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_epi16(S##i##j##1, W##i##j##1));
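                // How COMPUTE(i, j) works: _mm_unpacklo/hi_epi8(zero, w##j) puts each
                // weight byte in the high byte of a 16-bit lane and _mm_srai_epi16(.., 8)
                // sign-extends it (int8 -> int16). The source bytes are unpacked against
                // zero, i.e. zero-extended, consistent with inputs shifted by +128 into
                // the uint8 range (compensated via the neg128f/weightKernelSum term
                // below when post->inputScale is set). _mm_madd_epi16 multiplies and
                // pair-sums into int32, so d##i##j accumulates the dot product of
                // source pixel i with weight row j. Scalar sketch of one int32 lane:
                //   int32_t acc = (int32_t)(uint8_t)s[2k]     * (int32_t)(int8_t)w[2k]
                //               + (int32_t)(uint8_t)s[2k + 1] * (int32_t)(int8_t)w[2k + 1];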
                // source pixel 0 against weight rows 0..3
                COMPUTE(0, 0);
                COMPUTE(0, 1);
                COMPUTE(0, 2);
                COMPUTE(0, 3);
                // source pixel 1
                COMPUTE(1, 0);
                COMPUTE(1, 1);
                COMPUTE(1, 2);
                COMPUTE(1, 3);
                // source pixel 2
                COMPUTE(2, 0);
                COMPUTE(2, 1);
                COMPUTE(2, 2);
                COMPUTE(2, 3);
                // source pixel 3
                COMPUTE(3, 0);
                COMPUTE(3, 1);
                COMPUTE(3, 2);
                COMPUTE(3, 3);
                d0 = _mm_add_epi32(d0, d00);
                d1 = _mm_add_epi32(d1, d01);
                d2 = _mm_add_epi32(d2, d02);
                d3 = _mm_add_epi32(d3, d03);
                e0 = _mm_add_epi32(e0, d10);
                e1 = _mm_add_epi32(e1, d11);
                e2 = _mm_add_epi32(e2, d12);
                e3 = _mm_add_epi32(e3, d13);
                D0 = _mm_add_epi32(D0, d20);
                D1 = _mm_add_epi32(D1, d21);
                D2 = _mm_add_epi32(D2, d22);
                D3 = _mm_add_epi32(D3, d23);
                E0 = _mm_add_epi32(E0, d30);
                E1 = _mm_add_epi32(E1, d31);
                E2 = _mm_add_epi32(E2, d32);
                E3 = _mm_add_epi32(E3, d33);
            }
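            // After the depth loop, d0..d3 hold pixel 0 against weight rows 0..3
            // (four int32 partial lanes each), e* pixel 1, D* pixel 2, E* pixel 3.
            // The _mm_hadd_epi32 tree folds each register down to one lane and
            // repacks, so d0..d3 end up holding the four output channels of source
            // pixels 0..3 respectively.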
            d0 = _mm_hadd_epi32(d0, d1);
            d1 = _mm_hadd_epi32(d2, d3);
            d0 = _mm_hadd_epi32(d0, d1);
            e0 = _mm_hadd_epi32(e0, e1);
            e1 = _mm_hadd_epi32(e2, e3);
            d1 = _mm_hadd_epi32(e0, e1);
            D0 = _mm_hadd_epi32(D0, D1);
            D1 = _mm_hadd_epi32(D2, D3);
            d2 = _mm_hadd_epi32(D0, D1);
            E0 = _mm_hadd_epi32(E0, E1);
            E1 = _mm_hadd_epi32(E2, E3);
            d3 = _mm_hadd_epi32(E0, E1);
            auto scaleValue = _mm_loadu_ps(scale_dz);
            auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz);
            __m128 f0 = _mm_cvtepi32_ps(d0);
            __m128 f1 = _mm_cvtepi32_ps(d1);
            __m128 f2 = _mm_cvtepi32_ps(d2);
            __m128 f3 = _mm_cvtepi32_ps(d3);
            kernelSum0 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[0]);
            if (realDst > 1) kernelSum1 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[1]);
            if (realDst > 2) kernelSum2 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[2]);
            if (realDst > 3) kernelSum3 = _mm_set1_ps((post->srcKernelSum + bk * realDst)[3]);
            auto xy0_0 = _mm_mul_ps(kernelSum0, weightBiasValue); // x dimension: pixel 0
            auto xy0_1 = _mm_mul_ps(kernelSum1, weightBiasValue); // .. pixel 1
            auto xy0_2 = _mm_mul_ps(kernelSum2, weightBiasValue); // .. pixel 2
            auto xy0_3 = _mm_mul_ps(kernelSum3, weightBiasValue); // .. pixel 3
            f0 = _mm_mul_ps(f0, scaleValue);
            f1 = _mm_mul_ps(f1, scaleValue);
            f2 = _mm_mul_ps(f2, scaleValue);
            f3 = _mm_mul_ps(f3, scaleValue);
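            // Why kernelSum * weightBias is added: dequantizing the weights as
            // w_float ~ scale * w_int4 + weightBias (per output channel) gives
            //     sum_k x[k] * w_float[k] = scale * sum_k x[k] * w_int4[k]
            //                             + weightBias * sum_k x[k],
            // where sum_k x[k] over this block is exactly srcKernelSum. f0..f3 now
            // hold the first term; xy0_0..xy0_3 carry the second.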
            if (post->inputScale) {
                // per-pixel (and, with inputBias, per-block) input scales
                if (post->inputBias) {
                    extrascale0 = _mm_set1_ps((post->inputScale + bk * realDst)[0]);
                    if (realDst > 1) extrascale1 = _mm_set1_ps((post->inputScale + bk * realDst)[1]);
                    if (realDst > 2) extrascale2 = _mm_set1_ps((post->inputScale + bk * realDst)[2]);
                    if (realDst > 3) extrascale3 = _mm_set1_ps((post->inputScale + bk * realDst)[3]);
                }
                f0 = _mm_mul_ps(f0, extrascale0);
                f1 = _mm_mul_ps(f1, extrascale1);
                f2 = _mm_mul_ps(f2, extrascale2);
                f3 = _mm_mul_ps(f3, extrascale3);
                if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
                    if (post->inputBias) {
                        auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMM_INT8_UNIT) + bk * GEMM_INT8_UNIT;
                        auto wsum = _mm_loadu_ps(wsumDz);
                        extrabias0 = _mm_set1_ps((post->inputBias + bk * realDst)[0]);
                        if (realDst > 1) extrabias1 = _mm_set1_ps((post->inputBias + bk * realDst)[1]);
                        if (realDst > 2) extrabias2 = _mm_set1_ps((post->inputBias + bk * realDst)[2]);
                        if (realDst > 3) extrabias3 = _mm_set1_ps((post->inputBias + bk * realDst)[3]);
                        bias0 = _mm_mul_ps(extrabias0, wsum);
                        bias1 = _mm_mul_ps(extrabias1, wsum);
                        bias2 = _mm_mul_ps(extrabias2, wsum);
                        bias3 = _mm_mul_ps(extrabias3, wsum);
                    } else if (bk == blockNum - 1) { // if the input is not block-quantized, accumulate this only once
                        auto wsumDz = post->weightKernelSum + dz * GEMM_INT8_UNIT;
                        auto wsum = _mm_loadu_ps(wsumDz);
                        bias0 = _mm_mul_ps(_mm_mul_ps(extrascale0, neg128f), wsum);
                        bias1 = _mm_mul_ps(_mm_mul_ps(extrascale1, neg128f), wsum);
                        bias2 = _mm_mul_ps(_mm_mul_ps(extrascale2, neg128f), wsum);
                        bias3 = _mm_mul_ps(_mm_mul_ps(extrascale3, neg128f), wsum);
                    }
                    f0 = _mm_add_ps(f0, bias0);
                    f1 = _mm_add_ps(f1, bias1);
                    f2 = _mm_add_ps(f2, bias2);
                    f3 = _mm_add_ps(f3, bias3);
                }
            }
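            // Where the neg128f term comes from: with the inputs stored shifted,
            // x_u8 = x_i8 + 128, we have
            //     sum_k x_i8[k] * w[k] = sum_k x_u8[k] * w[k] - 128 * sum_k w[k],
            // and weightKernelSum supplies sum_k w[k] per output channel, so the
            // correction is inputScale * (-128) * weightKernelSum.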
            f0 = _mm_add_ps(f0, xy0_0);
            f1 = _mm_add_ps(f1, xy0_1);
            f2 = _mm_add_ps(f2, xy0_2);
            f3 = _mm_add_ps(f3, xy0_3);
            __m128 f[4] = {f0, f1, f2, f3};
            if (bk > 0) {
                // accumulate with the partial results of the previous blocks
                for (int j = 0; j < realDst; ++j) {
                    auto dstv = _mm_loadu_ps(((float*)accum_x) + j * 4);
                    f[j] = _mm_add_ps(dstv, f[j]);
                }
            }
            if (bk == blockNum - 1) {
                // last block: apply bias and clamping, then write the final fp32 result
                if (nullptr != biasPtr) {
                    const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT;
                    auto biasValue = _mm_loadu_ps(bias_dz);
                    for (int j = 0; j < realDst; ++j) {
                        f[j] = _mm_add_ps(biasValue, f[j]);
                    }
                }
                if (post->fp32minmax) {
                    for (int j = 0; j < realDst; ++j) {
                        f[j] = _mm_min_ps(f[j], fp32max);
                        f[j] = _mm_max_ps(f[j], fp32min);
                    }
                }
                for (int j = 0; j < realDst; ++j) {
                    _mm_storeu_ps(((float*)dst_x) + j * 4, f[j]);
                }
            } else {
                // not the last block yet: stash the partial sums in the accumulation buffer
                for (int j = 0; j < realDst; ++j) {
                    _mm_storeu_ps(((float*)accum_x) + j * 4, f[j]);
                }
            }
        }
    }
}