in source/backend/cpu/x86_x64/avx512/GemmInt8_VNNI.cpp [106:1610]
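// AVX-512 VNNI int8 GEMM micro-kernel with fused dequantization: computes a
// tile of dst = src * weight for up to GEMMINT8_AVX512_E (4) input columns,
// accumulating unsigned-src x signed-weight byte dot products with
// _mm512_dpbusds_epi32 and turning the int32 sums into float via per-channel
// scales and biases carried in `post`. Layout inferred from the loads below:
// PACK_UNIT = 16 output channels per __m512, dzUnit = GEMMINT8_AVX512_H /
// PACK_UNIT packs per outer iteration, GEMMINT8_AVX512_L int8 values reduced
// per dpbusds lane.
// Hypothetical call shape (argument names for illustration only):
//   _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(dstPtr, srcPtr, weightPtr,
//       srcDepthQuad, dstStrideBytes, dstDepthQuad, &postParams, /*realDst*/4);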
void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) {
const auto dst_step_tmp = dst_step / sizeof(int8_t);
// common tiling parameters
int dzUnit = GEMMINT8_AVX512_H / PACK_UNIT;
int dzU = dst_depth_quad / dzUnit;
int dzR = dst_depth_quad % dzUnit;
const float* biasPtr = nullptr;
const float* bias_dz = nullptr;
if (post->biasFloat) {
biasPtr = post->biasFloat;
}
// int8-output-path constants.
auto zero512 = _mm512_set1_ps(0.0f);
auto minValue = _mm512_set1_ps(post->minValue);
auto maxValue = _mm512_set1_ps(post->maxValue);
auto plus = _mm512_set1_ps(0.5f);
auto minus = _mm512_set1_ps(-0.5f);
auto offset = _mm256_set1_epi16(128);
// float-output-path constants
auto neg128f = _mm512_set1_ps(-128.f);
__m512 bias00, bias10, bias20, bias30, bias01, bias02, bias03, bias11, bias12, bias13, bias21, bias22, bias23, bias31, bias32, bias33;
__m512 fp32min, fp32max;
if (0 == post->useInt8 && post->fp32minmax) {
fp32min = _mm512_set1_ps((post->fp32minmax)[0]);
fp32max = _mm512_set1_ps((post->fp32minmax)[1]);
}
auto blockNum = post->blockNum;
const float* weightKernelSum_dz = nullptr;
auto accumbuff = post->accumBuffer;
__m512 kernelSum0, kernelSum1, kernelSum2, kernelSum3;
__m512 inputbias0, inputbias1, inputbias2, inputbias3;
__m512 inputscale0, inputscale1, inputscale2, inputscale3;
if (post->inputScale) {
inputscale0 = _mm512_set1_ps(post->inputScale[0]);
if (realDst > 1) {
inputscale1 = _mm512_set1_ps(post->inputScale[1]);
}
if (realDst > 2) {
inputscale2 = _mm512_set1_ps(post->inputScale[2]);
}
if (realDst > 3) {
inputscale3 = _mm512_set1_ps(post->inputScale[3]);
}
}
int weight_step_Z = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * GEMMINT8_AVX512_H);
int weight_step_Y = static_cast<int32_t>(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H);
int weightPackStride = GEMMINT8_AVX512_L * PACK_UNIT;
int weight_step_Z_remain = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * dzR * PACK_UNIT);
int source_step = realDst * PACK_UNIT;
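// Dispatch on the number of valid destination columns; each branch below runs
// the same pipeline unrolled for 4, 3, 2 or 1 columns.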
if (realDst == GEMMINT8_AVX512_E) {
for (int dz = 0; dz < dzU; ++dz) {
if (biasPtr) {
bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dz * dst_step_tmp * dzUnit;
auto accum_x = accumbuff;
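// Walk the quantization blocks along the reduced (K) axis; when blockNum > 1,
// partial float results are carried between blocks in accumbuff.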
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D3 = _mm512_set1_epi32(0);
__m512i D4 = _mm512_set1_epi32(0);
__m512i D5 = _mm512_set1_epi32(0);
__m512i D6 = _mm512_set1_epi32(0);
__m512i D7 = _mm512_set1_epi32(0);
__m512i D8 = _mm512_set1_epi32(0);
__m512i D9 = _mm512_set1_epi32(0);
__m512i D10 = _mm512_set1_epi32(0);
__m512i D11 = _mm512_set1_epi32(0);
__m512i D12 = _mm512_set1_epi32(0);
__m512i D13 = _mm512_set1_epi32(0);
__m512i D14 = _mm512_set1_epi32(0);
__m512i D15 = _mm512_set1_epi32(0);
// block's weight&scale&bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX512_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
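// Inner reduction over this block's depth: 4 weight vectors x 4 broadcast
// input columns feed 16 independent int32 accumulators.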
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
auto s1 = AVX512_BROADCAST_INT32(src_z + 1);
auto s2 = AVX512_BROADCAST_INT32(src_z + 2);
auto s3 = AVX512_BROADCAST_INT32(src_z + 3);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D1 = _mm512_dpbusds_epi32(D1, s1, w0);
D2 = _mm512_dpbusds_epi32(D2, s2, w0);
D3 = _mm512_dpbusds_epi32(D3, s3, w0);
D4 = _mm512_dpbusds_epi32(D4, s0, w1);
D5 = _mm512_dpbusds_epi32(D5, s1, w1);
D6 = _mm512_dpbusds_epi32(D6, s2, w1);
D7 = _mm512_dpbusds_epi32(D7, s3, w1);
D8 = _mm512_dpbusds_epi32(D8, s0, w2);
D9 = _mm512_dpbusds_epi32(D9, s1, w2);
D10 = _mm512_dpbusds_epi32(D10, s2, w2);
D11 = _mm512_dpbusds_epi32(D11, s3, w2);
D12 = _mm512_dpbusds_epi32(D12, s0, w3);
D13 = _mm512_dpbusds_epi32(D13, s1, w3);
D14 = _mm512_dpbusds_epi32(D14, s2, w3);
D15 = _mm512_dpbusds_epi32(D15, s3, w3);
}
// int32_t -> float
auto scaleValue0 = _mm512_loadu_ps(scale_dz);
auto scaleValue1 = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT);
auto scaleValue2 = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT);
auto scaleValue3 = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(weightBias_dz);
auto weightBiasValue1 = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT);
auto weightBiasValue2 = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT);
auto weightBiasValue3 = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[1]);
kernelSum2 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[2]);
kernelSum3 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[3]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputscale1 = _mm512_set1_ps((post->inputScale + bk * realDst)[1]);
inputscale2 = _mm512_set1_ps((post->inputScale + bk * realDst)[2]);
inputscale3 = _mm512_set1_ps((post->inputScale + bk * realDst)[3]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
inputbias1 = _mm512_set1_ps((post->inputBias + bk * realDst)[1]);
inputbias2 = _mm512_set1_ps((post->inputBias + bk * realDst)[2]);
inputbias3 = _mm512_set1_ps((post->inputBias + bk * realDst)[3]);
}
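// MUL_WEIGHT_SCALE(i, j) presumably converts D{i} to float and multiplies it
// by scaleValue{j} into f{i} (macro defined earlier in this file).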
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(1, 0);
MUL_WEIGHT_SCALE(2, 0);
MUL_WEIGHT_SCALE(3, 0);
MUL_WEIGHT_SCALE(4, 1);
MUL_WEIGHT_SCALE(5, 1);
MUL_WEIGHT_SCALE(6, 1);
MUL_WEIGHT_SCALE(7, 1);
MUL_WEIGHT_SCALE(8, 2);
MUL_WEIGHT_SCALE(9, 2);
MUL_WEIGHT_SCALE(10, 2);
MUL_WEIGHT_SCALE(11, 2);
MUL_WEIGHT_SCALE(12, 3);
MUL_WEIGHT_SCALE(13, 3);
MUL_WEIGHT_SCALE(14, 3);
MUL_WEIGHT_SCALE(15, 3);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f1 = _mm512_mul_ps(f1, inputscale1);
f2 = _mm512_mul_ps(f2, inputscale2);
f3 = _mm512_mul_ps(f3, inputscale3);
f4 = _mm512_mul_ps(f4, inputscale0);
f5 = _mm512_mul_ps(f5, inputscale1);
f6 = _mm512_mul_ps(f6, inputscale2);
f7 = _mm512_mul_ps(f7, inputscale3);
f8 = _mm512_mul_ps(f8, inputscale0);
f9 = _mm512_mul_ps(f9, inputscale1);
f10 = _mm512_mul_ps(f10, inputscale2);
f11 = _mm512_mul_ps(f11, inputscale3);
f12 = _mm512_mul_ps(f12, inputscale0);
f13 = _mm512_mul_ps(f13, inputscale1);
f14 = _mm512_mul_ps(f14, inputscale2);
f15 = _mm512_mul_ps(f15, inputscale3);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + bk * GEMMINT8_AVX512_H + dz * blockNum * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
bias02 = _mm512_mul_ps(inputbias2, wsum0);
bias03 = _mm512_mul_ps(inputbias3, wsum0);
bias10 = _mm512_mul_ps(inputbias0, wsum1);
bias11 = _mm512_mul_ps(inputbias1, wsum1);
bias12 = _mm512_mul_ps(inputbias2, wsum1);
bias13 = _mm512_mul_ps(inputbias3, wsum1);
bias20 = _mm512_mul_ps(inputbias0, wsum2);
bias21 = _mm512_mul_ps(inputbias1, wsum2);
bias22 = _mm512_mul_ps(inputbias2, wsum2);
bias23 = _mm512_mul_ps(inputbias3, wsum2);
bias30 = _mm512_mul_ps(inputbias0, wsum3);
bias31 = _mm512_mul_ps(inputbias1, wsum3);
bias32 = _mm512_mul_ps(inputbias2, wsum3);
bias33 = _mm512_mul_ps(inputbias3, wsum3);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dz * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum0);
bias02 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum0);
bias03 = _mm512_mul_ps(_mm512_mul_ps(inputscale3, neg128f), wsum0);
bias10 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum1);
bias11 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum1);
bias12 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum1);
bias13 = _mm512_mul_ps(_mm512_mul_ps(inputscale3, neg128f), wsum1);
bias20 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum2);
bias21 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum2);
bias22 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum2);
bias23 = _mm512_mul_ps(_mm512_mul_ps(inputscale3, neg128f), wsum2);
bias30 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum3);
bias31 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum3);
bias32 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum3);
bias33 = _mm512_mul_ps(_mm512_mul_ps(inputscale3, neg128f), wsum3);
}
f0 = _mm512_add_ps(f0, bias00);
f1 = _mm512_add_ps(f1, bias01);
f2 = _mm512_add_ps(f2, bias02);
f3 = _mm512_add_ps(f3, bias03);
f4 = _mm512_add_ps(f4, bias10);
f5 = _mm512_add_ps(f5, bias11);
f6 = _mm512_add_ps(f6, bias12);
f7 = _mm512_add_ps(f7, bias13);
f8 = _mm512_add_ps(f8, bias20);
f9 = _mm512_add_ps(f9, bias21);
f10 = _mm512_add_ps(f10, bias22);
f11 = _mm512_add_ps(f11, bias23);
f12 = _mm512_add_ps(f12, bias30);
f13 = _mm512_add_ps(f13, bias31);
f14 = _mm512_add_ps(f14, bias32);
f15 = _mm512_add_ps(f15, bias33);
}
}
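// Fold in the weight-bias correction: kernelSum[col] * weightBias[pack]
// (the weight zero-point term).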
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f1 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue0), f1);
f2 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue0), f2);
f3 = _mm512_add_ps(_mm512_mul_ps(kernelSum3, weightBiasValue0), f3);
f4 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue1), f4);
f5 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue1), f5);
f6 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue1), f6);
f7 = _mm512_add_ps(_mm512_mul_ps(kernelSum3, weightBiasValue1), f7);
f8 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue2), f8);
f9 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue2), f9);
f10 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue2), f10);
f11 = _mm512_add_ps(_mm512_mul_ps(kernelSum3, weightBiasValue2), f11);
f12 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue3), f12);
f13 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue3), f13);
f14 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue3), f14);
f15 = _mm512_add_ps(_mm512_mul_ps(kernelSum3, weightBiasValue3), f15);
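// int8 output path: add the float bias, then the POSTTREAT macro quantizes,
// clamps and stores the int8 results.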
if (post->useInt8 == 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
f3 = _mm512_add_ps(f3, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f5 = _mm512_add_ps(f5, biasValue4);
f6 = _mm512_add_ps(f6, biasValue4);
f7 = _mm512_add_ps(f7, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f9 = _mm512_add_ps(f9, biasValue8);
f10 = _mm512_add_ps(f10, biasValue8);
f11 = _mm512_add_ps(f11, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
f13 = _mm512_add_ps(f13, biasValue12);
f14 = _mm512_add_ps(f14, biasValue12);
f15 = _mm512_add_ps(f15, biasValue12);
}
POSTTREAT(0, 0);
POSTTREAT(1, 1);
POSTTREAT(2, 2);
POSTTREAT(3, 3);
dst_x += dst_step_tmp;
POSTTREAT(4, 0);
POSTTREAT(5, 1);
POSTTREAT(6, 2);
POSTTREAT(7, 3);
dst_x += dst_step_tmp;
POSTTREAT(8, 0);
POSTTREAT(9, 1);
POSTTREAT(10, 2);
POSTTREAT(11, 3);
dst_x += dst_step_tmp;
POSTTREAT(12, 0);
POSTTREAT(13, 1);
POSTTREAT(14, 2);
POSTTREAT(15, 3);
continue;
}
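// float output path: accumulate the partial sums from the previous blocks.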
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps(accum_x), f0);
f1 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 16), f1);
f2 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 16 * 2), f2);
f3 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 16 * 3), f3);
f4 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step), f4);
f5 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step + 16 * 1), f5);
f6 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step + 16 * 2), f6);
f7 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step + 16 * 3), f7);
f8 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step), f8);
f9 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step + 16 * 1), f9);
f10 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step + 16 * 2), f10);
f11 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step + 16 * 3), f11);
f12 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step), f12);
f13 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step + 16 * 1), f13);
f14 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step + 16 * 2), f14);
f15 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step + 16 * 3), f15);
}
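// Last block: add the float bias, clamp to [fp32min, fp32max] and store fp32;
// otherwise stash the partial sums in accumbuff for the next block.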
if (bk == blockNum - 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
f3 = _mm512_add_ps(f3, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f5 = _mm512_add_ps(f5, biasValue4);
f6 = _mm512_add_ps(f6, biasValue4);
f7 = _mm512_add_ps(f7, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f9 = _mm512_add_ps(f9, biasValue8);
f10 = _mm512_add_ps(f10, biasValue8);
f11 = _mm512_add_ps(f11, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
f13 = _mm512_add_ps(f13, biasValue12);
f14 = _mm512_add_ps(f14, biasValue12);
f15 = _mm512_add_ps(f15, biasValue12);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT(0,1,2,3);
POST_TREAT_FLOAT(4,5,6,7);
POST_TREAT_FLOAT(8,9,10,11);
POST_TREAT_FLOAT(12,13,14,15);
}
_mm512_storeu_ps(((float*)dst_x), f0);
_mm512_storeu_ps(((float*)dst_x) + 16, f1);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2);
_mm512_storeu_ps(((float*)dst_x) + 16 * 3, f3);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f6);
_mm512_storeu_ps(((float*)dst_x) + 16 * 3, f7);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f10);
_mm512_storeu_ps(((float*)dst_x) + 16 * 3, f11);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f14);
_mm512_storeu_ps(((float*)dst_x) + 16 * 3, f15);
} else {
_mm512_storeu_ps(accum_x, f0);
_mm512_storeu_ps(accum_x + 16, f1);
_mm512_storeu_ps(accum_x + 16 * 2, f2);
_mm512_storeu_ps(accum_x + 16 * 3, f3);
_mm512_storeu_ps(accum_x + source_step, f4);
_mm512_storeu_ps(accum_x + source_step + 16 * 1, f5);
_mm512_storeu_ps(accum_x + source_step + 16 * 2, f6);
_mm512_storeu_ps(accum_x + source_step + 16 * 3, f7);
_mm512_storeu_ps(accum_x + 2 * source_step, f8);
_mm512_storeu_ps(accum_x + 2 * source_step + 16 * 1, f9);
_mm512_storeu_ps(accum_x + 2 * source_step + 16 * 2, f10);
_mm512_storeu_ps(accum_x + 2 * source_step + 16 * 3, f11);
_mm512_storeu_ps(accum_x + 3 * source_step, f12);
_mm512_storeu_ps(accum_x + 3 * source_step + 16 * 1, f13);
_mm512_storeu_ps(accum_x + 3 * source_step + 16 * 2, f14);
_mm512_storeu_ps(accum_x + 3 * source_step + 16 * 3, f15);
}
}
} // dzU
// handle the remaining output-channel packs (dzR)
if (dzR == 0) {
return;
}
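// Remainder packs (dzR < dzUnit): one weight vector per pack; the scale and
// bias rows are packed here with stride dzR * PACK_UNIT instead of
// GEMMINT8_AVX512_H.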
auto weight_dz = weight + dzU * blockNum * weight_step_Z;
if (biasPtr) {
bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dzU * dst_step_tmp * dzUnit;
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D3 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weightDzSub + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
auto s1 = AVX512_BROADCAST_INT32(src_z + 1);
auto s2 = AVX512_BROADCAST_INT32(src_z + 2);
auto s3 = AVX512_BROADCAST_INT32(src_z + 3);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D1 = _mm512_dpbusds_epi32(D1, s1, w0);
D2 = _mm512_dpbusds_epi32(D2, s2, w0);
D3 = _mm512_dpbusds_epi32(D3, s3, w0);
}
auto scaleValue0 = _mm512_loadu_ps(scaleDz + i * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(biasDz + i * PACK_UNIT);
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(1, 0);
MUL_WEIGHT_SCALE(2, 0);
MUL_WEIGHT_SCALE(3, 0);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[1]);
kernelSum2 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[2]);
kernelSum3 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[3]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputscale1 = _mm512_set1_ps((post->inputScale + bk * realDst)[1]);
inputscale2 = _mm512_set1_ps((post->inputScale + bk * realDst)[2]);
inputscale3 = _mm512_set1_ps((post->inputScale + bk * realDst)[3]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
inputbias1 = _mm512_set1_ps((post->inputBias + bk * realDst)[1]);
inputbias2 = _mm512_set1_ps((post->inputBias + bk * realDst)[2]);
inputbias3 = _mm512_set1_ps((post->inputBias + bk * realDst)[3]);
}
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f1 = _mm512_mul_ps(f1, inputscale1);
f2 = _mm512_mul_ps(f2, inputscale2);
f3 = _mm512_mul_ps(f3, inputscale3);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
bias02 = _mm512_mul_ps(inputbias2, wsum0);
bias03 = _mm512_mul_ps(inputbias3, wsum0);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dzU * PACK_UNIT * dzUnit;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum0);
bias02 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum0);
bias03 = _mm512_mul_ps(_mm512_mul_ps(inputscale3, neg128f), wsum0);
}
f0 = _mm512_add_ps(f0, bias00);
f1 = _mm512_add_ps(f1, bias01);
f2 = _mm512_add_ps(f2, bias02);
f3 = _mm512_add_ps(f3, bias03);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f1 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue0), f1);
f2 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue0), f2);
f3 = _mm512_add_ps(_mm512_mul_ps(kernelSum3, weightBiasValue0), f3);
if (post->useInt8 == 1) {
if (nullptr != biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
f3 = _mm512_add_ps(f3, biasValue0);
}
POSTTREAT(0, 0);
POSTTREAT(1, 1);
POSTTREAT(2, 2);
POSTTREAT(3, 3);
dst_x += dst_step_tmp;
continue;
}
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps((float*)accum_x), f0);
f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)accum_x) + 16), f1);
f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)accum_x) + 16 * 2), f2);
f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)accum_x) + 16 * 3), f3);
}
if (bk == blockNum - 1) {
if (nullptr != biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
f3 = _mm512_add_ps(f3, biasValue0);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT(0,1,2,3);
}
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp), f0);
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp) + 16, f1);
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp) + 16 * 2, f2);
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp) + 16 * 3, f3);
} else {
_mm512_storeu_ps(((float*)accum_x), f0);
_mm512_storeu_ps(((float*)accum_x) + 16, f1);
_mm512_storeu_ps(((float*)accum_x) + 16 * 2, f2);
_mm512_storeu_ps(((float*)accum_x) + 16 * 3, f3);
}
}
}
return;
}
if (realDst == 3) {
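// Same pipeline as realDst == 4, unrolled for 3 columns.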
for (int dz = 0; dz < dzU; ++dz) {
if (biasPtr) {
bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dz * dst_step_tmp * dzUnit;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D4 = _mm512_set1_epi32(0);
__m512i D5 = _mm512_set1_epi32(0);
__m512i D6 = _mm512_set1_epi32(0);
__m512i D8 = _mm512_set1_epi32(0);
__m512i D9 = _mm512_set1_epi32(0);
__m512i D10 = _mm512_set1_epi32(0);
__m512i D12 = _mm512_set1_epi32(0);
__m512i D13 = _mm512_set1_epi32(0);
__m512i D14 = _mm512_set1_epi32(0);
// block's weight&scale&bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX512_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
auto s1 = AVX512_BROADCAST_INT32(src_z + 1);
auto s2 = AVX512_BROADCAST_INT32(src_z + 2);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D1 = _mm512_dpbusds_epi32(D1, s1, w0);
D2 = _mm512_dpbusds_epi32(D2, s2, w0);
D4 = _mm512_dpbusds_epi32(D4, s0, w1);
D5 = _mm512_dpbusds_epi32(D5, s1, w1);
D6 = _mm512_dpbusds_epi32(D6, s2, w1);
D8 = _mm512_dpbusds_epi32(D8, s0, w2);
D9 = _mm512_dpbusds_epi32(D9, s1, w2);
D10 = _mm512_dpbusds_epi32(D10, s2, w2);
D12 = _mm512_dpbusds_epi32(D12, s0, w3);
D13 = _mm512_dpbusds_epi32(D13, s1, w3);
D14 = _mm512_dpbusds_epi32(D14, s2, w3);
}
// int32_t -> float
auto scaleValue0 = _mm512_loadu_ps(scale_dz);
auto scaleValue1 = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT);
auto scaleValue2 = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT);
auto scaleValue3 = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(weightBias_dz);
auto weightBiasValue1 = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT);
auto weightBiasValue2 = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT);
auto weightBiasValue3 = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[1]);
kernelSum2 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[2]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputscale1 = _mm512_set1_ps((post->inputScale + bk * realDst)[1]);
inputscale2 = _mm512_set1_ps((post->inputScale + bk * realDst)[2]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
inputbias1 = _mm512_set1_ps((post->inputBias + bk * realDst)[1]);
inputbias2 = _mm512_set1_ps((post->inputBias + bk * realDst)[2]);
}
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(1, 0);
MUL_WEIGHT_SCALE(2, 0);
MUL_WEIGHT_SCALE(4, 1);
MUL_WEIGHT_SCALE(5, 1);
MUL_WEIGHT_SCALE(6, 1);
MUL_WEIGHT_SCALE(8, 2);
MUL_WEIGHT_SCALE(9, 2);
MUL_WEIGHT_SCALE(10, 2);
MUL_WEIGHT_SCALE(12, 3);
MUL_WEIGHT_SCALE(13, 3);
MUL_WEIGHT_SCALE(14, 3);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f1 = _mm512_mul_ps(f1, inputscale1);
f2 = _mm512_mul_ps(f2, inputscale2);
f4 = _mm512_mul_ps(f4, inputscale0);
f5 = _mm512_mul_ps(f5, inputscale1);
f6 = _mm512_mul_ps(f6, inputscale2);
f8 = _mm512_mul_ps(f8, inputscale0);
f9 = _mm512_mul_ps(f9, inputscale1);
f10 = _mm512_mul_ps(f10, inputscale2);
f12 = _mm512_mul_ps(f12, inputscale0);
f13 = _mm512_mul_ps(f13, inputscale1);
f14 = _mm512_mul_ps(f14, inputscale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + bk * GEMMINT8_AVX512_H + dz * blockNum * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
bias02 = _mm512_mul_ps(inputbias2, wsum0);
bias10 = _mm512_mul_ps(inputbias0, wsum1);
bias11 = _mm512_mul_ps(inputbias1, wsum1);
bias12 = _mm512_mul_ps(inputbias2, wsum1);
bias20 = _mm512_mul_ps(inputbias0, wsum2);
bias21 = _mm512_mul_ps(inputbias1, wsum2);
bias22 = _mm512_mul_ps(inputbias2, wsum2);
bias30 = _mm512_mul_ps(inputbias0, wsum3);
bias31 = _mm512_mul_ps(inputbias1, wsum3);
bias32 = _mm512_mul_ps(inputbias2, wsum3);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dz * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum0);
bias02 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum0);
bias10 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum1);
bias11 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum1);
bias12 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum1);
bias20 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum2);
bias21 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum2);
bias22 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum2);
bias30 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum3);
bias31 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum3);
bias32 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum3);
}
f0 = _mm512_add_ps(f0, bias00);
f1 = _mm512_add_ps(f1, bias01);
f2 = _mm512_add_ps(f2, bias02);
f4 = _mm512_add_ps(f4, bias10);
f5 = _mm512_add_ps(f5, bias11);
f6 = _mm512_add_ps(f6, bias12);
f8 = _mm512_add_ps(f8, bias20);
f9 = _mm512_add_ps(f9, bias21);
f10 = _mm512_add_ps(f10, bias22);
f12 = _mm512_add_ps(f12, bias30);
f13 = _mm512_add_ps(f13, bias31);
f14 = _mm512_add_ps(f14, bias32);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f1 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue0), f1);
f2 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue0), f2);
f4 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue1), f4);
f5 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue1), f5);
f6 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue1), f6);
f8 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue2), f8);
f9 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue2), f9);
f10 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue2), f10);
f12 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue3), f12);
f13 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue3), f13);
f14 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue3), f14);
if (post->useInt8 == 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f5 = _mm512_add_ps(f5, biasValue4);
f6 = _mm512_add_ps(f6, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f9 = _mm512_add_ps(f9, biasValue8);
f10 = _mm512_add_ps(f10, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
f13 = _mm512_add_ps(f13, biasValue12);
f14 = _mm512_add_ps(f14, biasValue12);
}
POSTTREAT(0, 0);
POSTTREAT(1, 1);
POSTTREAT(2, 2);
dst_x += dst_step_tmp;
POSTTREAT(4, 0);
POSTTREAT(5, 1);
POSTTREAT(6, 2);
dst_x += dst_step_tmp;
POSTTREAT(8, 0);
POSTTREAT(9, 1);
POSTTREAT(10, 2);
dst_x += dst_step_tmp;
POSTTREAT(12, 0);
POSTTREAT(13, 1);
POSTTREAT(14, 2);
continue;
}
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps(accum_x), f0);
f1 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 16), f1);
f2 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 16 * 2), f2);
f4 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step), f4);
f5 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step + 16 * 1), f5);
f6 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step + 16 * 2), f6);
f8 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step), f8);
f9 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step + 16 * 1), f9);
f10 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step + 16 * 2), f10);
f12 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step), f12);
f13 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step + 16 * 1), f13);
f14 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step + 16 * 2), f14);
}
if (bk == blockNum - 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f5 = _mm512_add_ps(f5, biasValue4);
f6 = _mm512_add_ps(f6, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f9 = _mm512_add_ps(f9, biasValue8);
f10 = _mm512_add_ps(f10, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
f13 = _mm512_add_ps(f13, biasValue12);
f14 = _mm512_add_ps(f14, biasValue12);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT_3(0,1,2);
POST_TREAT_FLOAT_3(4,5,6);
POST_TREAT_FLOAT_3(8,9,10);
POST_TREAT_FLOAT_3(12,13,14);
}
_mm512_storeu_ps(((float*)dst_x), f0);
_mm512_storeu_ps(((float*)dst_x) + 16, f1);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f6);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f10);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13);
_mm512_storeu_ps(((float*)dst_x) + 16 * 2, f14);
} else {
_mm512_storeu_ps(accum_x, f0);
_mm512_storeu_ps(accum_x + 16, f1);
_mm512_storeu_ps(accum_x + 16 * 2, f2);
_mm512_storeu_ps(accum_x + source_step, f4);
_mm512_storeu_ps(accum_x + source_step + 16 * 1, f5);
_mm512_storeu_ps(accum_x + source_step + 16 * 2, f6);
_mm512_storeu_ps(accum_x + 2 * source_step, f8);
_mm512_storeu_ps(accum_x + 2 * source_step + 16 * 1, f9);
_mm512_storeu_ps(accum_x + 2 * source_step + 16 * 2, f10);
_mm512_storeu_ps(accum_x + 3 * source_step, f12);
_mm512_storeu_ps(accum_x + 3 * source_step + 16 * 1, f13);
_mm512_storeu_ps(accum_x + 3 * source_step + 16 * 2, f14);
}
}
} // dzU
// handle the remaining output-channel packs (dzR)
auto weight_dz = weight + dzU * blockNum * weight_step_Z; // weight address for remaining
if (biasPtr) {
bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dzU * dst_step_tmp * dzUnit;
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weightDzSub + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
auto s1 = AVX512_BROADCAST_INT32(src_z + 1);
auto s2 = AVX512_BROADCAST_INT32(src_z + 2);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D1 = _mm512_dpbusds_epi32(D1, s1, w0);
D2 = _mm512_dpbusds_epi32(D2, s2, w0);
}
auto scaleValue0 = _mm512_loadu_ps(scaleDz + i * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(biasDz + i * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[1]);
kernelSum2 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[2]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputscale1 = _mm512_set1_ps((post->inputScale + bk * realDst)[1]);
inputscale2 = _mm512_set1_ps((post->inputScale + bk * realDst)[2]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
inputbias1 = _mm512_set1_ps((post->inputBias + bk * realDst)[1]);
inputbias2 = _mm512_set1_ps((post->inputBias + bk * realDst)[2]);
}
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(1, 0);
MUL_WEIGHT_SCALE(2, 0);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f1 = _mm512_mul_ps(f1, inputscale1);
f2 = _mm512_mul_ps(f2, inputscale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
bias02 = _mm512_mul_ps(inputbias2, wsum0);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dzU * PACK_UNIT * dzUnit;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum0);
bias02 = _mm512_mul_ps(_mm512_mul_ps(inputscale2, neg128f), wsum0);
}
f0 = _mm512_add_ps(f0, bias00);
f1 = _mm512_add_ps(f1, bias01);
f2 = _mm512_add_ps(f2, bias02);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f1 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue0), f1);
f2 = _mm512_add_ps(_mm512_mul_ps(kernelSum2, weightBiasValue0), f2);
if (post->useInt8 == 1) {
if (nullptr != biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
}
POSTTREAT(0, 0);
POSTTREAT(1, 1);
POSTTREAT(2, 2);
dst_x += dst_step_tmp;
continue;
}
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps((float*)accum_x), f0);
f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)accum_x) + 16), f1);
f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)accum_x) + 16 * 2), f2);
}
if (bk == blockNum - 1) {
if (nullptr != biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f2 = _mm512_add_ps(f2, biasValue0);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT_3(0,1,2);
}
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp), f0);
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp) + 16, f1);
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp) + 16 * 2, f2);
} else {
_mm512_storeu_ps(((float*)accum_x), f0);
_mm512_storeu_ps(((float*)accum_x) + 16, f1);
_mm512_storeu_ps(((float*)accum_x) + 16 * 2, f2);
}
}
}
return;
}
if (realDst == 2) {
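// Same pipeline as realDst == 4, unrolled for 2 columns.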
for (int dz = 0; dz < dzU; ++dz) {
if (biasPtr) {
bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dz * dst_step_tmp * dzUnit;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D4 = _mm512_set1_epi32(0);
__m512i D5 = _mm512_set1_epi32(0);
__m512i D8 = _mm512_set1_epi32(0);
__m512i D9 = _mm512_set1_epi32(0);
__m512i D12 = _mm512_set1_epi32(0);
__m512i D13 = _mm512_set1_epi32(0);
// block's weight&scale&bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX512_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
auto s1 = AVX512_BROADCAST_INT32(src_z + 1);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D1 = _mm512_dpbusds_epi32(D1, s1, w0);
D4 = _mm512_dpbusds_epi32(D4, s0, w1);
D5 = _mm512_dpbusds_epi32(D5, s1, w1);
D8 = _mm512_dpbusds_epi32(D8, s0, w2);
D9 = _mm512_dpbusds_epi32(D9, s1, w2);
D12 = _mm512_dpbusds_epi32(D12, s0, w3);
D13 = _mm512_dpbusds_epi32(D13, s1, w3);
}
// int32_t -> float
auto scaleValue0 = _mm512_loadu_ps(scale_dz);
auto scaleValue1 = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT);
auto scaleValue2 = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT);
auto scaleValue3 = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(weightBias_dz);
auto weightBiasValue1 = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT);
auto weightBiasValue2 = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT);
auto weightBiasValue3 = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[1]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputscale1 = _mm512_set1_ps((post->inputScale + bk * realDst)[1]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
inputbias1 = _mm512_set1_ps((post->inputBias + bk * realDst)[1]);
}
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(1, 0);
MUL_WEIGHT_SCALE(4, 1);
MUL_WEIGHT_SCALE(5, 1);
MUL_WEIGHT_SCALE(8, 2);
MUL_WEIGHT_SCALE(9, 2);
MUL_WEIGHT_SCALE(12, 3);
MUL_WEIGHT_SCALE(13, 3);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f1 = _mm512_mul_ps(f1, inputscale1);
f4 = _mm512_mul_ps(f4, inputscale0);
f5 = _mm512_mul_ps(f5, inputscale1);
f8 = _mm512_mul_ps(f8, inputscale0);
f9 = _mm512_mul_ps(f9, inputscale1);
f12 = _mm512_mul_ps(f12, inputscale0);
f13 = _mm512_mul_ps(f13, inputscale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + bk * GEMMINT8_AVX512_H + dz * blockNum * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
bias10 = _mm512_mul_ps(inputbias0, wsum1);
bias11 = _mm512_mul_ps(inputbias1, wsum1);
bias20 = _mm512_mul_ps(inputbias0, wsum2);
bias21 = _mm512_mul_ps(inputbias1, wsum2);
bias30 = _mm512_mul_ps(inputbias0, wsum3);
bias31 = _mm512_mul_ps(inputbias1, wsum3);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dz * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum0);
bias10 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum1);
bias11 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum1);
bias20 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum2);
bias21 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum2);
bias30 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum3);
bias31 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum3);
}
f0 = _mm512_add_ps(f0, bias00);
f1 = _mm512_add_ps(f1, bias01);
f4 = _mm512_add_ps(f4, bias10);
f5 = _mm512_add_ps(f5, bias11);
f8 = _mm512_add_ps(f8, bias20);
f9 = _mm512_add_ps(f9, bias21);
f12 = _mm512_add_ps(f12, bias30);
f13 = _mm512_add_ps(f13, bias31);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f1 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue0), f1);
f4 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue1), f4);
f5 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue1), f5);
f8 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue2), f8);
f9 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue2), f9);
f12 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue3), f12);
f13 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue3), f13);
if (post->useInt8 == 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f5 = _mm512_add_ps(f5, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f9 = _mm512_add_ps(f9, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
f13 = _mm512_add_ps(f13, biasValue12);
}
POSTTREAT(0, 0);
POSTTREAT(1, 1);
dst_x += dst_step_tmp;
POSTTREAT(4, 0);
POSTTREAT(5, 1);
dst_x += dst_step_tmp;
POSTTREAT(8, 0);
POSTTREAT(9, 1);
dst_x += dst_step_tmp;
POSTTREAT(12, 0);
POSTTREAT(13, 1);
continue;
}
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps(accum_x), f0);
f1 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 16), f1);
f4 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step), f4);
f5 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step + 16 * 1), f5);
f8 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step), f8);
f9 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step + 16 * 1), f9);
f12 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step), f12);
f13 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step + 16 * 1), f13);
}
if (bk == blockNum - 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f5 = _mm512_add_ps(f5, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f9 = _mm512_add_ps(f9, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
f13 = _mm512_add_ps(f13, biasValue12);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT_2(0,1);
POST_TREAT_FLOAT_2(4,5);
POST_TREAT_FLOAT_2(8,9);
POST_TREAT_FLOAT_2(12,13);
}
_mm512_storeu_ps(((float*)dst_x), f0);
_mm512_storeu_ps(((float*)dst_x) + 16, f1);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9);
dst_x += dst_step_tmp;
_mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12);
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13);
} else {
_mm512_storeu_ps(accum_x, f0);
_mm512_storeu_ps(accum_x + 16, f1);
_mm512_storeu_ps(accum_x + source_step, f4);
_mm512_storeu_ps(accum_x + source_step + 16 * 1, f5);
_mm512_storeu_ps(accum_x + 2 * source_step, f8);
_mm512_storeu_ps(accum_x + 2 * source_step + 16 * 1, f9);
_mm512_storeu_ps(accum_x + 3 * source_step, f12);
_mm512_storeu_ps(accum_x + 3 * source_step + 16 * 1, f13);
}
}
} // dzU
// handle the remaining output-channel packs (dzR)
auto weight_dz = weight + dzU * blockNum * weight_step_Z; // weight address for remaining
if (biasPtr) {
bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dzU * dst_step_tmp * dzUnit;
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weightDzSub + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
auto s1 = AVX512_BROADCAST_INT32(src_z + 1);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D1 = _mm512_dpbusds_epi32(D1, s1, w0);
}
auto scaleValue0 = _mm512_loadu_ps(scaleDz + i * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(biasDz + i * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
kernelSum1 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[1]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputscale1 = _mm512_set1_ps((post->inputScale + bk * realDst)[1]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
inputbias1 = _mm512_set1_ps((post->inputBias + bk * realDst)[1]);
}
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(1, 0);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f1 = _mm512_mul_ps(f1, inputscale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dzU * PACK_UNIT * dzUnit;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale1, neg128f), wsum0);
}
f0 = _mm512_add_ps(f0, bias00);
f1 = _mm512_add_ps(f1, bias01);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f1 = _mm512_add_ps(_mm512_mul_ps(kernelSum1, weightBiasValue0), f1);
if (post->useInt8 == 1) {
if (nullptr != biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
}
POSTTREAT(0, 0);
POSTTREAT(1, 1);
dst_x += dst_step_tmp;
continue;
}
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps((float*)accum_x), f0);
f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)accum_x) + 16), f1);
}
if (bk == blockNum - 1) {
if (nullptr != biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f1 = _mm512_add_ps(f1, biasValue0);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT_2(0,1);
}
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp), f0);
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp) + 16, f1);
} else {
_mm512_storeu_ps(((float*)accum_x), f0);
_mm512_storeu_ps(((float*)accum_x) + 16, f1);
}
}
}
return;
}
if (realDst == 1) {
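// Same pipeline as realDst == 4, unrolled for a single column.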
for (int dz = 0; dz < dzU; ++dz) {
if (biasPtr) {
bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dz * dst_step_tmp * dzUnit;
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D4 = _mm512_set1_epi32(0);
__m512i D8 = _mm512_set1_epi32(0);
__m512i D12 = _mm512_set1_epi32(0);
// block's weight&scale&bias
const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX512_H;
// block's input
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weight_dz + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L);
auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
D4 = _mm512_dpbusds_epi32(D4, s0, w1);
D8 = _mm512_dpbusds_epi32(D8, s0, w2);
D12 = _mm512_dpbusds_epi32(D12, s0, w3);
}
// int32_t -> float
auto scaleValue0 = _mm512_loadu_ps(scale_dz);
auto scaleValue1 = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT);
auto scaleValue2 = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT);
auto scaleValue3 = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(weightBias_dz);
auto weightBiasValue1 = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT);
auto weightBiasValue2 = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT);
auto weightBiasValue3 = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
}
MUL_WEIGHT_SCALE(0, 0);
MUL_WEIGHT_SCALE(4, 1);
MUL_WEIGHT_SCALE(8, 2);
MUL_WEIGHT_SCALE(12, 3);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
f4 = _mm512_mul_ps(f4, inputscale0);
f8 = _mm512_mul_ps(f8, inputscale0);
f12 = _mm512_mul_ps(f12, inputscale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + bk * GEMMINT8_AVX512_H + dz * blockNum * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias0, wsum1);
bias02 = _mm512_mul_ps(inputbias0, wsum2);
bias03 = _mm512_mul_ps(inputbias0, wsum3);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dz * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz);
auto wsum1 = _mm512_loadu_ps(weightKernelSum_dz + 1 * PACK_UNIT);
auto wsum2 = _mm512_loadu_ps(weightKernelSum_dz + 2 * PACK_UNIT);
auto wsum3 = _mm512_loadu_ps(weightKernelSum_dz + 3 * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
bias01 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum1);
bias02 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum2);
bias03 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum3);
}
f0 = _mm512_add_ps(f0, bias00);
f4 = _mm512_add_ps(f4, bias01);
f8 = _mm512_add_ps(f8, bias02);
f12 = _mm512_add_ps(f12, bias03);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
f4 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue1), f4);
f8 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue2), f8);
f12 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue3), f12);
if (post->useInt8 == 1) {
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
}
POSTTREAT(0, 0);
dst_x += dst_step_tmp;
POSTTREAT(4, 0);
dst_x += dst_step_tmp;
POSTTREAT(8, 0);
dst_x += dst_step_tmp;
POSTTREAT(12, 0);
continue;
}
if (bk > 0) { // add the partial sums carried in accumbuff when blockNum > 1
f0 = _mm512_add_ps(_mm512_loadu_ps(accum_x), f0);
f4 = _mm512_add_ps(_mm512_loadu_ps(accum_x + source_step), f4);
f8 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 2 * source_step), f8);
f12 = _mm512_add_ps(_mm512_loadu_ps(accum_x + 3 * source_step), f12);
}
if (bk == blockNum - 1) { // last block: post-process before storing to the destination
if (biasPtr) {
auto biasValue0 = _mm512_loadu_ps(bias_dz);
auto biasValue4 = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT);
auto biasValue8 = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT);
auto biasValue12 = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT);
f0 = _mm512_add_ps(f0, biasValue0);
f4 = _mm512_add_ps(f4, biasValue4);
f8 = _mm512_add_ps(f8, biasValue8);
f12 = _mm512_add_ps(f12, biasValue12);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT_1(0);
POST_TREAT_FLOAT_1(4);
POST_TREAT_FLOAT_1(8);
POST_TREAT_FLOAT_1(12);
}
_mm512_storeu_ps((float*)dst_x, f0);
_mm512_storeu_ps((float*)(dst_x + dst_step_tmp), f4);
_mm512_storeu_ps((float*)(dst_x + 2 * dst_step_tmp), f8);
_mm512_storeu_ps((float*)(dst_x + 3 * dst_step_tmp), f12);
} else { // save to accumbuff, to be added in the next block
_mm512_storeu_ps(accum_x, f0);
_mm512_storeu_ps(accum_x + source_step, f4);
_mm512_storeu_ps(accum_x + 2 * source_step, f8);
_mm512_storeu_ps(accum_x + 3 * source_step, f12);
}
}
} // dzU
// handle the remaining output-channel packs (dzR)
auto weight_dz = weight + dzU * blockNum * weight_step_Z; // weight address for remaining
if (biasPtr) {
bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit;
}
auto dst_x = dst + dzU * dst_step_tmp * dzUnit;
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
__m512i D0 = _mm512_set1_epi32(0);
for (int sz = 0; sz < src_depth_quad; ++sz) {
const auto weight_sz = weightDzSub + weight_step_Y * sz;
const auto src_z = (const float*)(src_x + sz * realDst * GEMMINT8_AVX512_L);
auto w0 = _mm512_loadu_si512(weight_sz);
auto s0 = AVX512_BROADCAST_INT32(src_z + 0);
D0 = _mm512_dpbusds_epi32(D0, s0, w0);
}
auto scaleValue0 = _mm512_loadu_ps(scaleDz + i * PACK_UNIT);
auto weightBiasValue0 = _mm512_loadu_ps(biasDz + i * PACK_UNIT);
// input info
kernelSum0 = _mm512_set1_ps((post->srcKernelSum + bk * realDst)[0]);
if (post->inputBias) {
inputscale0 = _mm512_set1_ps((post->inputScale + bk * realDst)[0]);
inputbias0 = _mm512_set1_ps((post->inputBias + bk * realDst)[0]);
}
MUL_WEIGHT_SCALE(0, 0);
if (post->inputScale) { // Batch quant
f0 = _mm512_mul_ps(f0, inputscale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
} else if (bk == 0) { // if the input is not block-quantized, accumulate only once
weightKernelSum_dz = post->weightKernelSum + dzU * PACK_UNIT * dzUnit;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(_mm512_mul_ps(inputscale0, neg128f), wsum0);
}
f0 = _mm512_add_ps(f0, bias00);
}
}
f0 = _mm512_add_ps(_mm512_mul_ps(kernelSum0, weightBiasValue0), f0);
if (post->useInt8 == 1) {
if (biasPtr) {
auto biasValue = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
SCALE_BIAS_VEC(0);
}
POSTTREAT(0, 0);
dst_x += dst_step_tmp;
continue;
}
if (bk > 0) {
f0 = _mm512_add_ps(_mm512_loadu_ps(accum_x), f0);
}
if (bk == blockNum - 1) {
if (biasPtr) {
auto biasValue = _mm512_loadu_ps(bias_dz + i * PACK_UNIT);
SCALE_BIAS_VEC(0);
}
if (post->fp32minmax) {
POST_TREAT_FLOAT_1(0);
}
_mm512_storeu_ps((float*)(dst_x + i * dst_step_tmp), f0);
} else {
_mm512_storeu_ps(((float*)accum_x), f0);
}
}
}
return;
}
}