in source/backend/cpu/x86_x64/avx/GemmInt8.cpp [1143:1425]
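// Fast path of the AVX2 int8 GEMM: handles up to GEMMINT8_AVX2_E (4 in this kernel)
// destination columns per call, accumulating u8 (src) x s8 (weight) products with
// _mm256_maddubs_epi16, then dequantizing with the per-channel scale and weight
// quantization bias. Results are stored as float when post->useInt8 == 0, otherwise
// requantized to int8 through the POSTTREAT macro.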
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
const auto dst_step_tmp = dst_step / sizeof(int8_t);
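// zero128 / minValue / maxValue / plus / minus / offset are consumed by the POSTTREAT
// requantization macro on the int8 output path (post->useInt8 != 0); the float output
// path only uses the fp32min/fp32max clamp set up below.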
auto zero128 = _mm256_set1_ps(0.0f);
auto minValue = _mm256_set1_ps(post->minValue);
auto maxValue = _mm256_set1_ps(post->maxValue);
auto plus = _mm256_set1_ps(0.5f);
auto minus = _mm256_set1_ps(-0.5f);
auto oneValue = _mm256_set1_epi16(1);
auto offset = _mm256_set1_epi32(128);
__m256 fp32min, fp32max;
if (0 == post->useInt8) {
fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
}
auto srcKernelSumPtr = post->srcKernelSum;
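// Broadcast the per-column input kernel sums; they feed the zero-point correction
// term (kernelSum * weightQuantBias) during dequantization.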
__m256 kernelSum0 = _mm256_setzero_ps();
__m256 kernelSum1 = _mm256_setzero_ps();
__m256 kernelSum2 = _mm256_setzero_ps();
__m256 kernelSum3 = _mm256_setzero_ps();
if (GEMMINT8_AVX2_E == realDst) {
kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]);
kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]);
kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
kernelSum3 = _mm256_set1_ps(post->srcKernelSum[3]);
} else {
kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]);
if (realDst > 1) {
kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]);
}
if (realDst > 2) {
kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
}
}
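// Per-output-block weight layout: src_depth_quad blocks of GEMMINT8_AVX2_L x GEMMINT8_AVX2_H
// int8 weights, followed by GEMMINT8_AVX2_H float scales and GEMMINT8_AVX2_H float weight
// quantization biases (hence the extra 4 * 2 * GEMMINT8_AVX2_H bytes in weight_step_Z).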
int weight_step_Z = src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) + 4 * 2 * GEMMINT8_AVX2_H;
int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
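// One int32 accumulator register per destination column; each register holds
// AVX2_PACKINT8 output channels.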
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
__m256i D03 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
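// Each _mm256_broadcast_ss reinterprets 4 consecutive int8 inputs (one GEMMINT8_AVX2_L
// group for one column) as a float and splats them across every 32-bit lane.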
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
auto s3 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 3));
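// maddubs: u8 (src) x s8 (weight) products summed in pairs to int16; the madd with
// oneValue widens and adds the int16 pairs into the int32 accumulators.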
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
D03 = _mm256_add_epi32(D03, _mm256_madd_epi16(_mm256_maddubs_epi16(s3, w0), oneValue));
}
auto D0 = D00;
auto D1 = D01;
auto D2 = D02;
auto D3 = D03;
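// Convert the int32 accumulators to float, then apply the per-channel scale, the
// kernel-sum correction term, and the bias term.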
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
auto f3 = _mm256_cvtepi32_ps(D3);
// zero-point correction: x_kernelSum * w_quantZero, one term per destination column
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // column 0
auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // column 2
auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // column 3
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
f3 = _mm256_mul_ps(f3, scaleValue);
f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1);
f2 = _mm256_add_ps(f2, xy0_2);
f3 = _mm256_add_ps(f3, xy0_3);
auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue);
f3 = _mm256_add_ps(f3, biasValue);
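// post->useInt8 == 0: clamp to the fp32 range and store float results; otherwise
// POSTTREAT requantizes to int8 and stores.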
if (post->useInt8 == 0) {
f0 = _mm256_min_ps(f0, fp32max);
f1 = _mm256_min_ps(f1, fp32max);
f2 = _mm256_min_ps(f2, fp32max);
f3 = _mm256_min_ps(f3, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
f2 = _mm256_max_ps(f2, fp32min);
f3 = _mm256_max_ps(f3, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
} else {
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
POSTTREAT(3);
}
}
return;
}
if (3 == realDst) {
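// Same computation as the 4-column path above, with 3 active columns.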
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
auto s2 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 2));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
D02 = _mm256_add_epi32(D02, _mm256_madd_epi16(_mm256_maddubs_epi16(s2, w0), oneValue));
}
auto D0 = D00;
auto D1 = D01;
auto D2 = D02;
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
// zero-point correction: x_kernelSum * w_quantZero, one term per destination column
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // column 0
auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // column 2
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1);
f2 = _mm256_add_ps(f2, xy0_2);
auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue);
if (post->useInt8 == 0) {
f0 = _mm256_min_ps(f0, fp32max);
f1 = _mm256_min_ps(f1, fp32max);
f2 = _mm256_min_ps(f2, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
f2 = _mm256_max_ps(f2, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} else {
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
}
}
return;
}
if (2 == realDst) {
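// 2-column variant of the same computation.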
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
auto s1 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 1));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
D01 = _mm256_add_epi32(D01, _mm256_madd_epi16(_mm256_maddubs_epi16(s1, w0), oneValue));
}
auto D0 = D00;
auto D1 = D01;
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
// zero-point correction: x_kernelSum * w_quantZero, one term per destination column
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // column 0
auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // column 1
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1);
auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
if (post->useInt8 == 0) {
f0 = _mm256_min_ps(f0, fp32max);
f1 = _mm256_min_ps(f1, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} else {
POSTTREAT(0);
POSTTREAT(1);
}
}
return;
}
if (1 == realDst) {
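// 1-column variant of the same computation.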
for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src;
auto dst_x = dst_z;
__m256i D00 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = _mm256_loadu_si256((__m256i*)weight_sz);
auto s0 = _mm256_castps_si256(_mm256_broadcast_ss((float*)src_z + 0));
D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue));
}
auto D0 = D00;
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
// zero-point correction: x_kernelSum * w_quantZero
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue);
f0 = _mm256_mul_ps(f0, scaleValue);
f0 = _mm256_add_ps(f0, xy0_0);
auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue);
if (post->useInt8 == 0) {
f0 = _mm256_min_ps(f0, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
} else {
POSTTREAT(0);
}
}
return;
}
}