in source/backend/cpu/x86_x64/avx/GemmInt8.cpp [560:1142]
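// AVX2 int8 GEMM micro-kernel. Computes up to GEMMINT8_AVX2_E (4) destination columns for
// dst_depth_quad groups of output channels, accumulating over src_depth_quad input groups and
// blockNum quantization blocks. Roughly, each block's int32 accumulator is dequantized as
//   f = (acc * weightScale) [* inputScale] + srcKernelSum * weightBias [+ weightKernelSum terms],
// then either requantized to int8 via POSTTREAT (post->useInt8 == 1) or kept in float: partial
// block sums pass through post->accumBuffer, and the last block adds the bias, clamps with
// post->fp32minmax and writes dst.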
void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
const auto dst_step_tmp = dst_step / sizeof(int8_t);
auto zero128 = _mm256_set1_ps(0.0f);
auto minValue = _mm256_set1_ps(post->minValue);
auto maxValue = _mm256_set1_ps(post->maxValue);
auto plus = _mm256_set1_ps(0.5f);
auto minus = _mm256_set1_ps(-0.5f);
auto offset = _mm256_set1_epi32(128);
__m256 fp32min, fp32max;
if (0 == post->useInt8 && post->fp32minmax) {
fp32min = _mm256_set1_ps((post->fp32minmax)[0]);
fp32max = _mm256_set1_ps((post->fp32minmax)[1]);
}
const float* biasPtr = nullptr;
if (post->biasFloat) {
biasPtr = post->biasFloat;
}
int inputBlockNum = 1;
auto accumbuff = post->accumBuffer;
auto blockNum = post->blockNum;
if (post->inputBias) {
inputBlockNum = blockNum;
}
int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float) * GEMMINT8_AVX2_H;
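// Packed-weight layout per (dz, block): src_depth_quad tiles of int8 weights (weight_step_Y
// bytes each), followed by GEMMINT8_AVX2_H float scales and GEMMINT8_AVX2_H float biases;
// that float tail is the 2 * sizeof(float) * GEMMINT8_AVX2_H term in weight_step_Z.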
auto srcKernelSumPtr = post->srcKernelSum;
__m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3;
auto neg128_f = _mm256_set1_ps(-128.f);
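// -128.f compensates the unsigned (offset-by-128) input encoding: the float paths below fold
// scale * (-128) * weightKernelSum into the bias when useInt8 == 0.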
__m256 extrascale0 = _mm256_setzero_ps();
__m256 extrascale1 = _mm256_setzero_ps();
__m256 extrascale2 = _mm256_setzero_ps();
__m256 extrascale3 = _mm256_setzero_ps();
__m256 extrabias0 = _mm256_setzero_ps();
__m256 extrabias1 = _mm256_setzero_ps();
__m256 extrabias2 = _mm256_setzero_ps();
__m256 extrabias3 = _mm256_setzero_ps();
if (post->inputScale) {
if (GEMMINT8_AVX2_E == realDst) {
extrascale0 = _mm256_set1_ps(post->inputScale[0]);
extrascale1 = _mm256_set1_ps(post->inputScale[1]);
extrascale2 = _mm256_set1_ps(post->inputScale[2]);
extrascale3 = _mm256_set1_ps(post->inputScale[3]);
} else {
extrascale0 = _mm256_set1_ps(post->inputScale[0]);
if (realDst > 1) {
extrascale1 = _mm256_set1_ps(post->inputScale[1]);
}
if (realDst > 2) {
extrascale2 = _mm256_set1_ps(post->inputScale[2]);
}
}
}
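// With only a per-tensor input scale these extrascale values stay fixed for the whole call;
// when post->inputBias is set they are re-read per block inside the loops below.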
__m256 bias0, bias1, bias2, bias3;
if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
auto dst_x = dst + dz * dst_step_tmp;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
// this block's weight, scale and bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
kernelSum3 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[3]);
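// srcKernelSum: per-column sums of this block's quantized input, multiplied below by the
// per-channel weight quantization bias (weightBias_dz).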
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
__m256i D03 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
__m256i D11 = _mm256_set1_epi32(0);
__m256i D12 = _mm256_set1_epi32(0);
__m256i D13 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * weight_step_Y;
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
auto s3 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 3));
auto S0 = _mm256_cvtepu8_epi16(s0);
auto S1 = _mm256_cvtepu8_epi16(s1);
auto S2 = _mm256_cvtepu8_epi16(s2);
auto S3 = _mm256_cvtepu8_epi16(s3);
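// COMPUTE(i, j) is a macro defined elsewhere in this file; it multiply-accumulates the widened
// int16 weights W<i> with the broadcast input lanes S<j> into the int32 accumulators D<i><j>.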
COMPUTE(0, 0);
COMPUTE(1, 0);
COMPUTE(0, 1);
COMPUTE(1, 1);
COMPUTE(0, 2);
COMPUTE(1, 2);
COMPUTE(0, 3);
COMPUTE(1, 3);
}
auto D0 = NORMAL_HADD(D00, D10);
auto D1 = NORMAL_HADD(D01, D11);
auto D2 = NORMAL_HADD(D02, D12);
auto D3 = NORMAL_HADD(D03, D13);
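// NORMAL_HADD (also a macro from this file) folds each pair of accumulators into one
// 8 x int32 vector, one lane per packed output channel.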
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
auto f3 = _mm256_cvtepi32_ps(D3);
// x_kernelSum x w_quantZero
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension: first column
auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // second column
auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // third column
auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // fourth column
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
f3 = _mm256_mul_ps(f3, scaleValue);
if (post->inputScale) {
if (post->inputBias) {
extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
extrascale3 = _mm256_set1_ps((post->inputScale + bk * realDst)[3]);
}
f0 = _mm256_mul_ps(f0, extrascale0);
f1 = _mm256_mul_ps(f1, extrascale1);
f2 = _mm256_mul_ps(f2, extrascale2);
f3 = _mm256_mul_ps(f3, extrascale3);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
if (post->inputBias) {
auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
extrabias3 = _mm256_set1_ps((post->inputBias + bk * realDst)[3]);
bias0 = _mm256_mul_ps(extrabias0, wsum);
bias1 = _mm256_mul_ps(extrabias1, wsum);
bias2 = _mm256_mul_ps(extrabias2, wsum);
bias3 = _mm256_mul_ps(extrabias3, wsum);
} else if (bk == blockNum - 1) { // if the input is not block-quantized, accumulate this term only once
auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
bias3 = _mm256_mul_ps(_mm256_mul_ps(extrascale3, neg128_f), wsum);
}
f0 = _mm256_add_ps(f0, bias0);
f1 = _mm256_add_ps(f1, bias1);
f2 = _mm256_add_ps(f2, bias2);
f3 = _mm256_add_ps(f3, bias3);
}
}
f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1);
f2 = _mm256_add_ps(f2, xy0_2);
f3 = _mm256_add_ps(f3, xy0_3);
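// useInt8 == 1: requantize straight to int8 output. Otherwise keep float results, carrying
// partial sums across blocks in accumbuff and writing dst only on the last block.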
if (post->useInt8 == 1) {
if (biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue);
f3 = _mm256_add_ps(f3, biasValue);
}
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
POSTTREAT(3);
} else {
if (bk > 0) {
auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1);
f2 = _mm256_add_ps(f2, dstv2);
f3 = _mm256_add_ps(f3, dstv3);
}
if (bk == blockNum - 1) {
if (biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue);
f3 = _mm256_add_ps(f3, biasValue);
}
if (post->fp32minmax) {
f0 = _mm256_min_ps(f0, fp32max);
f1 = _mm256_min_ps(f1, fp32max);
f2 = _mm256_min_ps(f2, fp32max);
f3 = _mm256_min_ps(f3, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
f2 = _mm256_max_ps(f2, fp32min);
f3 = _mm256_max_ps(f3, fp32min);
}
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
}
}
}
}
return;
}
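// The realDst == 3 / 2 / 1 branches below repeat the same pipeline for partial tiles with
// fewer live source columns.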
if (3 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
auto dst_x = dst + dz * dst_step_tmp;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
// this block's weight, scale and bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
kernelSum2 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[2]);
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
__m256i D11 = _mm256_set1_epi32(0);
__m256i D12 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * weight_step_Y;
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2));
auto S0 = _mm256_cvtepu8_epi16(s0);
auto S1 = _mm256_cvtepu8_epi16(s1);
auto S2 = _mm256_cvtepu8_epi16(s2);
COMPUTE(0, 0);
COMPUTE(1, 0);
COMPUTE(0, 1);
COMPUTE(1, 1);
COMPUTE(0, 2);
COMPUTE(1, 2);
}
auto D0 = NORMAL_HADD(D00, D10);
auto D1 = NORMAL_HADD(D01, D11);
auto D2 = NORMAL_HADD(D02, D12);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
auto f2 = _mm256_cvtepi32_ps(D2);
// x_kernelSum x w_quantZero
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension: first column
auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // second column
auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // third column
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
f2 = _mm256_mul_ps(f2, scaleValue);
if (post->inputScale) {
if (post->inputBias) {
extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
extrascale2 = _mm256_set1_ps((post->inputScale + bk * realDst)[2]);
}
f0 = _mm256_mul_ps(f0, extrascale0);
f1 = _mm256_mul_ps(f1, extrascale1);
f2 = _mm256_mul_ps(f2, extrascale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
if (post->inputBias) {
auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
extrabias2 = _mm256_set1_ps((post->inputBias + bk * realDst)[2]);
bias0 = _mm256_mul_ps(extrabias0, wsum);
bias1 = _mm256_mul_ps(extrabias1, wsum);
bias2 = _mm256_mul_ps(extrabias2, wsum);
} else if (bk == blockNum - 1) { // if the input is not block-quantized, accumulate this term only once
auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
bias2 = _mm256_mul_ps(_mm256_mul_ps(extrascale2, neg128_f), wsum);
}
f0 = _mm256_add_ps(f0, bias0);
f1 = _mm256_add_ps(f1, bias1);
f2 = _mm256_add_ps(f2, bias2);
}
}
f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1);
f2 = _mm256_add_ps(f2, xy0_2);
if (post->useInt8 == 1) {
if (nullptr != biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue);
}
POSTTREAT(0);
POSTTREAT(1);
POSTTREAT(2);
} else {
if (bk > 0) {
auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1);
f2 = _mm256_add_ps(f2, dstv2);
}
if (bk == blockNum - 1) {
if (nullptr != biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue);
}
if (post->fp32minmax) {
f0 = _mm256_min_ps(f0, fp32max);
f1 = _mm256_min_ps(f1, fp32max);
f2 = _mm256_min_ps(f2, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
f2 = _mm256_max_ps(f2, fp32min);
}
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
}
}
}
}
return;
}
if (2 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
auto dst_x = dst + dz * dst_step_tmp;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
// this block's weight, scale and bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[1]);
__m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
__m256i D11 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * weight_step_Y;
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1));
auto S0 = _mm256_cvtepu8_epi16(s0);
auto S1 = _mm256_cvtepu8_epi16(s1);
COMPUTE(0, 0);
COMPUTE(1, 0);
COMPUTE(0, 1);
COMPUTE(1, 1);
}
auto D0 = NORMAL_HADD(D00, D10);
auto D1 = NORMAL_HADD(D01, D11);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
auto f1 = _mm256_cvtepi32_ps(D1);
// x_kernelSum x w_quantZero
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension: first column
auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // second column
f0 = _mm256_mul_ps(f0, scaleValue);
f1 = _mm256_mul_ps(f1, scaleValue);
if (post->inputScale) {
if (post->inputBias) {
extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
extrascale1 = _mm256_set1_ps((post->inputScale + bk * realDst)[1]);
}
f0 = _mm256_mul_ps(f0, extrascale0);
f1 = _mm256_mul_ps(f1, extrascale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
if (post->inputBias) {
auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
extrabias1 = _mm256_set1_ps((post->inputBias + bk * realDst)[1]);
bias0 = _mm256_mul_ps(extrabias0, wsum);
bias1 = _mm256_mul_ps(extrabias1, wsum);
} else if (bk == blockNum - 1) { // if the input is not block-quantized, accumulate this term only once
auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
bias1 = _mm256_mul_ps(_mm256_mul_ps(extrascale1, neg128_f), wsum);
}
f0 = _mm256_add_ps(f0, bias0);
f1 = _mm256_add_ps(f1, bias1);
}
}
f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1);
if (post->useInt8 == 1) {
if (nullptr != biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
}
POSTTREAT(0);
POSTTREAT(1);
} else {
if (bk > 0) {
auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1);
}
if (bk == blockNum - 1) {
if (nullptr != biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue);
}
if (post->fp32minmax) {
f0 = _mm256_min_ps(f0, fp32max);
f1 = _mm256_min_ps(f1, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
}
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
}
}
}
}
return;
}
if (1 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) {
auto dst_x = dst + dz * dst_step_tmp;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
// this block's weight, scale and bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX2_L * realDst;
// source kernel sum
kernelSum0 = _mm256_set1_ps((post->srcKernelSum + bk * realDst)[0]);
__m256i D00 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + sz * weight_step_Y;
const auto src_z = src_x + sz * GEMMINT8_AVX2_L * realDst;
auto w0 = mm_loadu_si128(weight_sz + 16 * 0);
auto w1 = mm_loadu_si128(weight_sz + 16 * 1);
auto W0 = _mm256_cvtepi8_epi16(w0);
auto W1 = _mm256_cvtepi8_epi16(w1);
auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0));
auto S0 = _mm256_cvtepu8_epi16(s0);
COMPUTE(0, 0);
COMPUTE(1, 0);
}
auto D0 = NORMAL_HADD(D00, D10);
auto scaleValue = _mm256_loadu_ps(scale_dz);
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
auto f0 = _mm256_cvtepi32_ps(D0);
// x_kernelSum x w_quantZero
auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension: first column
f0 = _mm256_mul_ps(f0, scaleValue);
if (post->inputScale) {
if (post->inputBias) {
extrascale0 = _mm256_set1_ps((post->inputScale + bk * realDst)[0]);
}
f0 = _mm256_mul_ps(f0, extrascale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == blockNum - 1))) {
if (post->inputBias) {
auto wsumDz = post->weightKernelSum + dz * (blockNum * GEMMINT8_AVX2_H) + bk * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
extrabias0 = _mm256_set1_ps((post->inputBias + bk * realDst)[0]);
bias0 = _mm256_mul_ps(extrabias0, wsum);
} else if (bk == blockNum - 1) { // if the input is not block-quantized, accumulate this term only once
auto wsumDz = post->weightKernelSum + dz * GEMMINT8_AVX2_H;
auto wsum = _mm256_loadu_ps(wsumDz);
bias0 = _mm256_mul_ps(_mm256_mul_ps(extrascale0, neg128_f), wsum);
}
f0 = _mm256_add_ps(f0, bias0);
}
}
f0 = _mm256_add_ps(f0, xy0_0);
if (post->useInt8 == 1) {
if (nullptr != biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
}
POSTTREAT(0);
} else {
if (bk > 0) {
auto dstv = _mm256_loadu_ps(((float*)accum_x));
f0 = _mm256_add_ps(f0, dstv);
}
if (bk == blockNum - 1) {
if (biasPtr) {
const auto bias_dz = biasPtr + dz * AVX2_PACKINT8;
auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue);
}
if (post->fp32minmax) {
f0 = _mm256_min_ps(f0, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
}
_mm256_storeu_ps(((float*)dst_x), f0);
} else {
_mm256_storeu_ps(((float*)accum_x), f0);
}
}
}
}
return;
}
}