source/backend/cpu/compute/Int8FunctionsOpt.cpp (2,151 lines of code) (raw):

// // Int8FunctionsOpt.cpp // MNN // // Created by MNN on 2018/08/15. // Copyright © 2018, Alibaba Group Holding Limited // #include <math.h> #include <cstring> // for memset #include "Int8FunctionsOpt.h" #include "core/Macro.h" #include "core/CommonCompute.hpp" #include "CommonOptFunction.h" #include "math/Vec.hpp" #ifdef MNN_USE_NEON #include <arm_neon.h> extern "C" { void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx); void MNNAvgPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor); void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack); #if defined(__aarch64__) // aarch32 sdot workaround void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNGemmInt8AddBiasScale_ARMV86_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); void MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); void MNNSumByAxisLForMatmul_A_ARM82(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); #if defined(MNN_LOW_MEMORY) // int4 weight gemmInt8 kernel void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); // Tools to dynamic-quant fp16-input data. #ifdef MNN_USE_ARMV82 void DynamicQuanInput_ARM82(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec); // int8 weight gemmInt8 kernel to return fp16-output data. void MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void DynamicQuanInputAndReorder_ARM82(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin, ssize_t aMax, const float* zeroPoint, size_t ocQuad, size_t offset); #endif #endif #endif // __aarch64__ } #endif // MNN_USE_NEON /* layout should be optimized for int8 source: source matrix is h x l transpose: if false, export compressed matrix as h x l, other export as l x h. */ void MNNPackForSparseQuantMatMul_B(int8_t* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const int8_t* source, size_t h, size_t kernelCount, size_t icCount, const int eP) { // 1. in quant convolution, source B layout is OC x (IC * KH * KW), // the dest layout of weight is BCSC(block compressed sparse colum) format, which is OC(!=0) x (KH*KW*IC!=0), as a canceled result, just do BCSR // 2. IC would be switched into the last dim. // BCSC int columOffset = 0; int i = 0; auto subSource = source; size_t l = kernelCount * icCount; for (; i + sparseBlockOC <= h; i += sparseBlockOC) { *NNZMap = 0; for(int ik = 0; ik < kernelCount; ik += 1) { auto kernelSource = subSource + ik; for(int ic = 0; ic < icCount; ic += 1) { if (!MNN::CommonCompute::checkAllZeros(kernelSource, l, sparseBlockOC, 1)) { for (int ioc = 0; ioc < sparseBlockOC; ioc++) { *dest = *(kernelSource + ioc * l); dest++; } *NNZMap = *NNZMap + 1; *dataOffsetMap = columOffset; dataOffsetMap++; columOffset = 0; } columOffset += eP; kernelSource += kernelCount; } } NNZMap++; columOffset -= l * eP; subSource += sparseBlockOC * l; } for (; i < h; i++) { *NNZMap = 0; for(int ik = 0; ik < kernelCount; ik += 1) { auto kernelSource = subSource + ik; for(int ic = 0; ic < icCount; ic += 1) { if (*kernelSource != 0) { *dest = *kernelSource; dest++; *NNZMap = *NNZMap + 1; *dataOffsetMap = columOffset; dataOffsetMap++; columOffset = 0; } columOffset += eP; kernelSource += kernelCount; } } NNZMap++; columOffset -= l * eP; subSource += l; } *dataOffsetMap = columOffset; // return; } void MNNGetSparseQuantMatMulPackMode(int* eP, int *lP, int* hP) { #if defined(__arm__) && !defined(__aarch64__) *eP = 8; #else *eP = 16; #endif *lP = 1; *hP = 4; // hp is corresponding to sparse block along right matrix colum dimension. in ramdom sparse, it is 1. return; } static void _MNNPackC4Int8ForMatMul_ASparse(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) { int number = info[0]; int eReal = info[1]; int eDest = info[2]; int offset = info[3]; for (int n=0; n<number; ++n) { int e = el[4 * n + 0]; int l = el[4 * n + 1]; int eOffset = el[4 * n + 2]; int lOffset = el[4 * n + 3]; auto dest = destOrigin + lOffset * eDest + eOffset; auto source = sourceGroup[n]; for (int y=0; y<e; ++y) { auto yR = y % eDest; for (int x=0; x<l; ++x) { auto xR = x % 4; auto xC = x / 4; dest[(x) * eDest + yR] = source[xC * eReal * 4 + y * 4 * offset + xR]; } } } } #ifndef MNN_USE_NEON void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) { size_t eSize = sparseQuantParam[0]; size_t eP = sparseQuantParam[1]; size_t aStride = sparseQuantParam[2]; size_t l = sparseQuantParam[3]; size_t h = sparseQuantParam[4]; size_t cStride = sparseQuantParam[5]; const int32_t* bias = post->bias; const float* scales = post->scale; const int32_t maxValue = post->maxValue; const int32_t minValue = post->minValue; const int sparseBlockOC = 4; const int8_t * a = A; size_t ie = 0; for (ie = 0; ie < eSize && eP <= eSize; ie += eP) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; a += diff; const int8_t * w = B; int8_t * blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; for (size_t ih = 0; ih < h; ih++) { auto ihPack = ih >> 2; auto ihSubIndex = ih & 0x03; auto c = blockC + ihPack * cStride + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; int32_t acc2 = initValue; int32_t acc3 = initValue; int32_t acc4 = initValue; int32_t acc5 = initValue; int32_t acc6 = initValue; int32_t acc7 = initValue; int32_t acc8 = initValue; int32_t acc9 = initValue; int32_t acc10 = initValue; int32_t acc11 = initValue; int32_t acc12 = initValue; int32_t acc13 = initValue; int32_t acc14 = initValue; int32_t acc15 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t a4 = a[4]; const int8_t a5 = a[5]; const int8_t a6 = a[6]; const int8_t a7 = a[7]; const int8_t a8 = a[8]; const int8_t a9 = a[9]; const int8_t a10 = a[10]; const int8_t a11 = a[11]; const int8_t a12 = a[12]; const int8_t a13 = a[13]; const int8_t a14 = a[14]; const int8_t a15 = a[15]; const int8_t oneW = *w++; // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%f, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {16}); // MNN_PRINT("\n"); a = a + diff; acc0 += (int32_t)a0 * (int32_t)oneW; acc1 += (int32_t)a1 * (int32_t)oneW; acc2 += (int32_t)a2 * (int32_t)oneW; acc3 += (int32_t)a3 * (int32_t)oneW; acc4 += (int32_t)a4 * (int32_t)oneW; acc5 += (int32_t)a5 * (int32_t)oneW; acc6 += (int32_t)a6 * (int32_t)oneW; acc7 += (int32_t)a7 * (int32_t)oneW; acc8 += (int32_t)a8 * (int32_t)oneW; acc9 += (int32_t)a9 * (int32_t)oneW; acc10 += (int32_t)a10 * (int32_t)oneW; acc11 += (int32_t)a11 * (int32_t)oneW; acc12 += (int32_t)a12 * (int32_t)oneW; acc13 += (int32_t)a13 * (int32_t)oneW; acc14 += (int32_t)a14 * (int32_t)oneW; acc15 += (int32_t)a15 * (int32_t)oneW; } int8_t result0; // in assemmbly code, consider reuse acc0[0-8] bit int8_t result1; int8_t result2; int8_t result3; int8_t result4; int8_t result5; int8_t result6; int8_t result7; int8_t result8; int8_t result9; int8_t result10; int8_t result11; int8_t result12; int8_t result13; int8_t result14; int8_t result15; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue)))); result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue)))); result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue)))); result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue)))); result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue)))); result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue)))); result8 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc8)), float(minValue)))); result9 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc9)), float(minValue)))); result10 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc10)), float(minValue)))); result11 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc11)), float(minValue)))); result12 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc12)), float(minValue)))); result13 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc13)), float(minValue)))); result14 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc14)), float(minValue)))); result15 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc15)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue)); result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue)); result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue)); result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue)); result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue)); result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue)); result8 = static_cast<int8_t>(std::max(std::min(maxValue, acc8), minValue)); result9 = static_cast<int8_t>(std::max(std::min(maxValue, acc9), minValue)); result10 = static_cast<int8_t>(std::max(std::min(maxValue, acc10), minValue)); result11 = static_cast<int8_t>(std::max(std::min(maxValue, acc11), minValue)); result12 = static_cast<int8_t>(std::max(std::min(maxValue, acc12), minValue)); result13 = static_cast<int8_t>(std::max(std::min(maxValue, acc13), minValue)); result14 = static_cast<int8_t>(std::max(std::min(maxValue, acc14), minValue)); result15 = static_cast<int8_t>(std::max(std::min(maxValue, acc15), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; c[4 * 2] = result2; c[4 * 3] = result3; c[4 * 4] = result4; c[4 * 5] = result5; c[4 * 6] = result6; c[4 * 7] = result7; c[4 * 8] = result8; c[4 * 9] = result9; c[4 * 10] = result10; c[4 * 11] = result11; c[4 * 12] = result12; c[4 * 13] = result13; c[4 * 14] = result14; c[4 * 15] = result15; } a += aStride; } if (eSize & 0x08) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // a = blockA + diff; a += diff; const int8_t* w = B; int8_t* blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; for (size_t ih = 0; ih < h; ih++) { auto ihPack = ih >> 2; auto ihSubIndex = ih & 0x03; auto c = blockC + ihPack * cStride + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; int32_t acc2 = initValue; int32_t acc3 = initValue; int32_t acc4 = initValue; int32_t acc5 = initValue; int32_t acc6 = initValue; int32_t acc7 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t a4 = a[4]; const int8_t a5 = a[5]; const int8_t a6 = a[6]; const int8_t a7 = a[7]; const int8_t oneW = *w++; // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-7]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {8}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); acc1 += int32_t(a1) * int32_t(oneW); acc2 += int32_t(a2) * int32_t(oneW); acc3 += int32_t(a3) * int32_t(oneW); acc4 += int32_t(a4) * int32_t(oneW); acc5 += int32_t(a5) * int32_t(oneW); acc6 += int32_t(a6) * int32_t(oneW); acc7 += int32_t(a7) * int32_t(oneW); } int8_t result0; int8_t result1; int8_t result2; int8_t result3; int8_t result4; int8_t result5; int8_t result6; int8_t result7; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue)))); result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue)))); result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue)))); result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue)))); result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue)))); result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue)); result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue)); result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue)); result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue)); result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue)); result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; c[4 * 2] = result2; c[4 * 3] = result3; c[4 * 4] = result4; c[4 * 5] = result5; c[4 * 6] = result6; c[4 * 7] = result7; } ie += 8; a += 8; } if (eSize & 0x04) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // a = blockA + diff; a += diff; const int8_t* w = B; int8_t* blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; for (size_t ih = 0; ih < h; ih++) { auto ihPack = ih >> 2; auto ihSubIndex = ih & 0x03; auto c = blockC + ihPack * cStride + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; int32_t acc2 = initValue; int32_t acc3 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t oneW = *w++; // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-3]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {4}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); acc1 += int32_t(a1) * int32_t(oneW); acc2 += int32_t(a2) * int32_t(oneW); acc3 += int32_t(a3) * int32_t(oneW); } int8_t result0; int8_t result1; int8_t result2; int8_t result3; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue)))); result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue)); result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; c[4 * 2] = result2; c[4 * 3] = result3; } ie += 4; a += 4; } if (eSize & 0x02) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // a = blockA + diff; a += diff; const int8_t* w = B; int8_t* blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; for (size_t ih = 0; ih < h; ih++) { auto ihPack = ih >> 2; auto ihSubIndex = ih & 0x03; auto c = blockC + ihPack * cStride + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t oneW = *w++; // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-1]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {2}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); acc1 += int32_t(a1) * int32_t(oneW); } int8_t result0; int8_t result1; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; } ie += 2; a += 2; } if (eSize & 0x01) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // const float* a = blockA + diff; a += diff; const int8_t * w = B; int8_t * blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; for (size_t ih = 0; ih < h; ih++) { auto ihPack = ih >> 2; auto ihSubIndex = ih & 0x03; auto c = blockC + ihPack * cStride + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t oneW = *w++; // MNN_PRINT("1-loop: ie:%zu, a offset:%ld, c offset:%ld, w offset:%ld, w value:%d, a value[0]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {1}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); } int8_t result0; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; } ie += 1; // a += 1; } } void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) { size_t eSize = sparseQuantParam[0]; size_t eP = sparseQuantParam[1]; size_t aStride = sparseQuantParam[2]; size_t l = sparseQuantParam[3]; size_t h = sparseQuantParam[4]; size_t cStride = sparseQuantParam[5]; const int32_t* bias = post->bias; const float* scales = post->scale; const int32_t maxValue = post->maxValue; const int32_t minValue = post->minValue; const int sparseBlockOC = 4; const int8_t * a = A; size_t ie = 0; for (ie = 0; ie < eSize && eP <= eSize; ie += eP) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; a += diff; const int8_t * w = B; int8_t * blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; size_t ih = 0; for (; ih < (h & (~0x03)); ih += sparseBlockOC) { auto ihPack = ih >> 2; auto c = blockC + ihPack * cStride; int32_t initValue[4] = {0, 0, 0, 0}; if (nullptr != bias) { memcpy(initValue, bias + ih, 4 * sizeof(int32_t)); } int32_t acc0[4]; int32_t acc1[4]; int32_t acc2[4]; int32_t acc3[4]; int32_t acc4[4]; int32_t acc5[4]; int32_t acc6[4]; int32_t acc7[4]; int32_t acc8[4]; int32_t acc9[4]; int32_t acc10[4]; int32_t acc11[4]; int32_t acc12[4]; int32_t acc13[4]; int32_t acc14[4]; int32_t acc15[4]; memcpy(acc0, initValue, 4 * sizeof(int32_t)); memcpy(acc1, initValue, 4 * sizeof(int32_t)); memcpy(acc2, initValue, 4 * sizeof(int32_t)); memcpy(acc3, initValue, 4 * sizeof(int32_t)); memcpy(acc4, initValue, 4 * sizeof(int32_t)); memcpy(acc5, initValue, 4 * sizeof(int32_t)); memcpy(acc6, initValue, 4 * sizeof(int32_t)); memcpy(acc7, initValue, 4 * sizeof(int32_t)); memcpy(acc8, initValue, 4 * sizeof(int32_t)); memcpy(acc9, initValue, 4 * sizeof(int32_t)); memcpy(acc10, initValue, 4 * sizeof(int32_t)); memcpy(acc11, initValue, 4 * sizeof(int32_t)); memcpy(acc12, initValue, 4 * sizeof(int32_t)); memcpy(acc13, initValue, 4 * sizeof(int32_t)); memcpy(acc14, initValue, 4 * sizeof(int32_t)); memcpy(acc15, initValue, 4 * sizeof(int32_t)); const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t a4 = a[4]; const int8_t a5 = a[5]; const int8_t a6 = a[6]; const int8_t a7 = a[7]; const int8_t a8 = a[8]; const int8_t a9 = a[9]; const int8_t a10 = a[10]; const int8_t a11 = a[11]; const int8_t a12 = a[12]; const int8_t a13 = a[13]; const int8_t a14 = a[14]; const int8_t a15 = a[15]; const int8_t wv[4] = {*w++, *w++, *w++, *w++}; // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%f, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {16}); // MNN_PRINT("\n"); a = a + diff; for (int lane = 0; lane < 4; lane++) { acc0[lane] += (int32_t)a0 * (int32_t)wv[lane]; acc1[lane] += (int32_t)a1 * (int32_t)wv[lane]; acc2[lane] += (int32_t)a2 * (int32_t)wv[lane]; acc3[lane] += (int32_t)a3 * (int32_t)wv[lane]; acc4[lane] += (int32_t)a4 * (int32_t)wv[lane]; acc5[lane] += (int32_t)a5 * (int32_t)wv[lane]; acc6[lane] += (int32_t)a6 * (int32_t)wv[lane]; acc7[lane] += (int32_t)a7 * (int32_t)wv[lane]; acc8[lane] += (int32_t)a8 * (int32_t)wv[lane]; acc9[lane] += (int32_t)a9 * (int32_t)wv[lane]; acc10[lane] += (int32_t)a10 * (int32_t)wv[lane]; acc11[lane] += (int32_t)a11 * (int32_t)wv[lane]; acc12[lane] += (int32_t)a12 * (int32_t)wv[lane]; acc13[lane] += (int32_t)a13 * (int32_t)wv[lane]; acc14[lane] += (int32_t)a14 * (int32_t)wv[lane]; acc15[lane] += (int32_t)a15 * (int32_t)wv[lane]; } } int8_t result0[4]; int8_t result1[4]; int8_t result2[4]; int8_t result3[4]; int8_t result4[4]; int8_t result5[4]; int8_t result6[4]; int8_t result7[4]; int8_t result8[4]; int8_t result9[4]; int8_t result10[4]; int8_t result11[4]; int8_t result12[4]; int8_t result13[4]; int8_t result14[4]; int8_t result15[4]; if (scales) { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue)))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue)))); result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue)))); result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue)))); result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc4[lane])), float(minValue)))); result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc5[lane])), float(minValue)))); result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc6[lane])), float(minValue)))); result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc7[lane])), float(minValue)))); result8[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc8[lane])), float(minValue)))); result9[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc9[lane])), float(minValue)))); result10[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc10[lane])), float(minValue)))); result11[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc11[lane])), float(minValue)))); result12[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc12[lane])), float(minValue)))); result13[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc13[lane])), float(minValue)))); result14[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc14[lane])), float(minValue)))); result15[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc15[lane])), float(minValue)))); } } else { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue))); result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue))); result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue))); result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc4[lane]), minValue))); result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc5[lane]), minValue))); result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc6[lane]), minValue))); result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc7[lane]), minValue))); result8[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc8[lane]), minValue))); result9[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc9[lane]), minValue))); result10[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc10[lane]), minValue))); result11[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc11[lane]), minValue))); result12[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc12[lane]), minValue))); result13[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc13[lane]), minValue))); result14[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc14[lane]), minValue))); result15[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc15[lane]), minValue))); } } memcpy(c , result0, 4 * sizeof(int8_t)); // store continuous c memcpy(c + 4 , result1, 4 * sizeof(int8_t)); memcpy(c + 4 * 2 , result2, 4 * sizeof(int8_t)); memcpy(c + 4 * 3 , result3, 4 * sizeof(int8_t)); memcpy(c + 4 * 4 , result4, 4 * sizeof(int8_t)); memcpy(c + 4 * 5 , result5, 4 * sizeof(int8_t)); memcpy(c + 4 * 6 , result6, 4 * sizeof(int8_t)); memcpy(c + 4 * 7 , result7, 4 * sizeof(int8_t)); memcpy(c + 4 * 8 , result8, 4 * sizeof(int8_t)); memcpy(c + 4 * 9 , result9, 4 * sizeof(int8_t)); memcpy(c + 4 * 10, result10, 4 * sizeof(int8_t)); memcpy(c + 4 * 11, result11, 4 * sizeof(int8_t)); memcpy(c + 4 * 12, result12, 4 * sizeof(int8_t)); memcpy(c + 4 * 13, result13, 4 * sizeof(int8_t)); memcpy(c + 4 * 14, result14, 4 * sizeof(int8_t)); memcpy(c + 4 * 15, result15, 4 * sizeof(int8_t)); } blockC += (h >> 2) * cStride; for (; ih < h; ih++) { auto ihSubIndex = ih & 0x03; auto c = blockC + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; int32_t acc2 = initValue; int32_t acc3 = initValue; int32_t acc4 = initValue; int32_t acc5 = initValue; int32_t acc6 = initValue; int32_t acc7 = initValue; int32_t acc8 = initValue; int32_t acc9 = initValue; int32_t acc10 = initValue; int32_t acc11 = initValue; int32_t acc12 = initValue; int32_t acc13 = initValue; int32_t acc14 = initValue; int32_t acc15 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t a4 = a[4]; const int8_t a5 = a[5]; const int8_t a6 = a[6]; const int8_t a7 = a[7]; const int8_t a8 = a[8]; const int8_t a9 = a[9]; const int8_t a10 = a[10]; const int8_t a11 = a[11]; const int8_t a12 = a[12]; const int8_t a13 = a[13]; const int8_t a14 = a[14]; const int8_t a15 = a[15]; const int8_t oneW = *w++; // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%f, a value[0-15]:", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {16}); // MNN_PRINT("\n"); a = a + diff; acc0 += (int32_t)a0 * (int32_t)oneW; acc1 += (int32_t)a1 * (int32_t)oneW; acc2 += (int32_t)a2 * (int32_t)oneW; acc3 += (int32_t)a3 * (int32_t)oneW; acc4 += (int32_t)a4 * (int32_t)oneW; acc5 += (int32_t)a5 * (int32_t)oneW; acc6 += (int32_t)a6 * (int32_t)oneW; acc7 += (int32_t)a7 * (int32_t)oneW; acc8 += (int32_t)a8 * (int32_t)oneW; acc9 += (int32_t)a9 * (int32_t)oneW; acc10 += (int32_t)a10 * (int32_t)oneW; acc11 += (int32_t)a11 * (int32_t)oneW; acc12 += (int32_t)a12 * (int32_t)oneW; acc13 += (int32_t)a13 * (int32_t)oneW; acc14 += (int32_t)a14 * (int32_t)oneW; acc15 += (int32_t)a15 * (int32_t)oneW; } int8_t result0; // in assemmbly code, consider reuse acc0[0-8] bit int8_t result1; int8_t result2; int8_t result3; int8_t result4; int8_t result5; int8_t result6; int8_t result7; int8_t result8; int8_t result9; int8_t result10; int8_t result11; int8_t result12; int8_t result13; int8_t result14; int8_t result15; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue)))); result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue)))); result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue)))); result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue)))); result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue)))); result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue)))); result8 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc8)), float(minValue)))); result9 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc9)), float(minValue)))); result10 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc10)), float(minValue)))); result11 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc11)), float(minValue)))); result12 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc12)), float(minValue)))); result13 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc13)), float(minValue)))); result14 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc14)), float(minValue)))); result15 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc15)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue)); result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue)); result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue)); result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue)); result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue)); result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue)); result8 = static_cast<int8_t>(std::max(std::min(maxValue, acc8), minValue)); result9 = static_cast<int8_t>(std::max(std::min(maxValue, acc9), minValue)); result10 = static_cast<int8_t>(std::max(std::min(maxValue, acc10), minValue)); result11 = static_cast<int8_t>(std::max(std::min(maxValue, acc11), minValue)); result12 = static_cast<int8_t>(std::max(std::min(maxValue, acc12), minValue)); result13 = static_cast<int8_t>(std::max(std::min(maxValue, acc13), minValue)); result14 = static_cast<int8_t>(std::max(std::min(maxValue, acc14), minValue)); result15 = static_cast<int8_t>(std::max(std::min(maxValue, acc15), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; c[4 * 2] = result2; c[4 * 3] = result3; c[4 * 4] = result4; c[4 * 5] = result5; c[4 * 6] = result6; c[4 * 7] = result7; c[4 * 8] = result8; c[4 * 9] = result9; c[4 * 10] = result10; c[4 * 11] = result11; c[4 * 12] = result12; c[4 * 13] = result13; c[4 * 14] = result14; c[4 * 15] = result15; } a += aStride; } if (eSize & 0x08) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // a = blockA + diff; a += diff; const int8_t* w = B; int8_t* blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; size_t ih = 0; for (; ih < (h & (~0x03)); ih += sparseBlockOC) { auto ihPack = ih >> 2; auto c = blockC + ihPack * cStride; int32_t initValue[4] = {0, 0, 0, 0}; if (nullptr != bias) { memcpy(initValue, bias + ih, 4 * sizeof(int32_t)); } int32_t acc0[4]; int32_t acc1[4]; int32_t acc2[4]; int32_t acc3[4]; int32_t acc4[4]; int32_t acc5[4]; int32_t acc6[4]; int32_t acc7[4]; memcpy(acc0, initValue, 4 * sizeof(int32_t)); memcpy(acc1, initValue, 4 * sizeof(int32_t)); memcpy(acc2, initValue, 4 * sizeof(int32_t)); memcpy(acc3, initValue, 4 * sizeof(int32_t)); memcpy(acc4, initValue, 4 * sizeof(int32_t)); memcpy(acc5, initValue, 4 * sizeof(int32_t)); memcpy(acc6, initValue, 4 * sizeof(int32_t)); memcpy(acc7, initValue, 4 * sizeof(int32_t)); const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t a4 = a[4]; const int8_t a5 = a[5]; const int8_t a6 = a[6]; const int8_t a7 = a[7]; const int8_t wv[4] = {*w++, *w++, *w++, *w++}; // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value[0-3]:, a value[0-7]:\n", ie, a - A, w - B - 1, c - C); // formatMatrix(wv, {4}); // formatMatrix(a, {8}); // MNN_PRINT("\n"); a = a + diff; for (int lane = 0; lane < 4; lane++) { acc0[lane] += int32_t(a0) * int32_t(wv[lane]); acc1[lane] += int32_t(a1) * int32_t(wv[lane]); acc2[lane] += int32_t(a2) * int32_t(wv[lane]); acc3[lane] += int32_t(a3) * int32_t(wv[lane]); acc4[lane] += int32_t(a4) * int32_t(wv[lane]); acc5[lane] += int32_t(a5) * int32_t(wv[lane]); acc6[lane] += int32_t(a6) * int32_t(wv[lane]); acc7[lane] += int32_t(a7) * int32_t(wv[lane]); } } int8_t result0[4]; int8_t result1[4]; int8_t result2[4]; int8_t result3[4]; int8_t result4[4]; int8_t result5[4]; int8_t result6[4]; int8_t result7[4]; if (scales) { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue)))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue)))); result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue)))); result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue)))); result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc4[lane])), float(minValue)))); result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc5[lane])), float(minValue)))); result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc6[lane])), float(minValue)))); result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc7[lane])), float(minValue)))); } } else { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue))); result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue))); result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue))); result4[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc4[lane]), minValue))); result5[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc5[lane]), minValue))); result6[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc6[lane]), minValue))); result7[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc7[lane]), minValue))); } } memcpy(c , result0, 4 * sizeof(int8_t)); // store continuous c memcpy(c + 4 , result1, 4 * sizeof(int8_t)); memcpy(c + 4 * 2 , result2, 4 * sizeof(int8_t)); memcpy(c + 4 * 3 , result3, 4 * sizeof(int8_t)); memcpy(c + 4 * 4 , result4, 4 * sizeof(int8_t)); memcpy(c + 4 * 5 , result5, 4 * sizeof(int8_t)); memcpy(c + 4 * 6 , result6, 4 * sizeof(int8_t)); memcpy(c + 4 * 7 , result7, 4 * sizeof(int8_t)); } blockC += (ih >> 2) * cStride; for (; ih < h; ih++) { auto ihSubIndex = ih & 0x03; auto c = blockC + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; int32_t acc2 = initValue; int32_t acc3 = initValue; int32_t acc4 = initValue; int32_t acc5 = initValue; int32_t acc6 = initValue; int32_t acc7 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t a4 = a[4]; const int8_t a5 = a[5]; const int8_t a6 = a[6]; const int8_t a7 = a[7]; const int8_t oneW = *w++; // MNN_PRINT("8-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-7]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {8}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); acc1 += int32_t(a1) * int32_t(oneW); acc2 += int32_t(a2) * int32_t(oneW); acc3 += int32_t(a3) * int32_t(oneW); acc4 += int32_t(a4) * int32_t(oneW); acc5 += int32_t(a5) * int32_t(oneW); acc6 += int32_t(a6) * int32_t(oneW); acc7 += int32_t(a7) * int32_t(oneW); } int8_t result0; int8_t result1; int8_t result2; int8_t result3; int8_t result4; int8_t result5; int8_t result6; int8_t result7; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue)))); result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue)))); result4 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc4)), float(minValue)))); result5 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc5)), float(minValue)))); result6 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc6)), float(minValue)))); result7 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc7)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue)); result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue)); result4 = static_cast<int8_t>(std::max(std::min(maxValue, acc4), minValue)); result5 = static_cast<int8_t>(std::max(std::min(maxValue, acc5), minValue)); result6 = static_cast<int8_t>(std::max(std::min(maxValue, acc6), minValue)); result7 = static_cast<int8_t>(std::max(std::min(maxValue, acc7), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; c[4 * 2] = result2; c[4 * 3] = result3; c[4 * 4] = result4; c[4 * 5] = result5; c[4 * 6] = result6; c[4 * 7] = result7; } ie += 8; a += 8; } if (eSize & 0x04) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // a = blockA + diff; a += diff; const int8_t* w = B; int8_t* blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; size_t ih = 0; for (; ih < (h & (~0x03)); ih += sparseBlockOC) { auto ihPack = ih >> 2; auto c = blockC + ihPack * cStride; int32_t initValue[4] = {0, 0, 0, 0}; if (nullptr != bias) { memcpy(initValue, bias + ih, 4 * sizeof(int32_t)); } int32_t acc0[4]; int32_t acc1[4]; int32_t acc2[4]; int32_t acc3[4]; memcpy(acc0, initValue, 4 * sizeof(int32_t)); memcpy(acc1, initValue, 4 * sizeof(int32_t)); memcpy(acc2, initValue, 4 * sizeof(int32_t)); memcpy(acc3, initValue, 4 * sizeof(int32_t)); const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t wv[4] = {*w++, *w++, *w++, *w++}; // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:, a value[0-3]:\n", ie, a - A, w - B - 1, c - C); // formatMatrix(wv, {4}); // formatMatrix(a, {4}); // MNN_PRINT("\n"); a = a + diff; for (int lane = 0; lane < 4; lane++) { acc0[lane] += int32_t(a0) * int32_t(wv[lane]); acc1[lane] += int32_t(a1) * int32_t(wv[lane]); acc2[lane] += int32_t(a2) * int32_t(wv[lane]); acc3[lane] += int32_t(a3) * int32_t(wv[lane]); } } int8_t result0[4]; int8_t result1[4]; int8_t result2[4]; int8_t result3[4]; if (scales) { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue)))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue)))); result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc2[lane])), float(minValue)))); result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc3[lane])), float(minValue)))); } } else { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue))); result2[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc2[lane]), minValue))); result3[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc3[lane]), minValue))); } } memcpy(c , result0, 4 * sizeof(int8_t)); // store continuous c memcpy(c + 4 , result1, 4 * sizeof(int8_t)); memcpy(c + 4 * 2 , result2, 4 * sizeof(int8_t)); memcpy(c + 4 * 3 , result3, 4 * sizeof(int8_t)); } blockC += (ih >> 2) * cStride; for (; ih < h; ih++) { auto ihSubIndex = ih & 0x03; auto c = blockC + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; int32_t acc2 = initValue; int32_t acc3 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t a2 = a[2]; const int8_t a3 = a[3]; const int8_t oneW = *w++; // MNN_PRINT("4-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-3]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {4}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); acc1 += int32_t(a1) * int32_t(oneW); acc2 += int32_t(a2) * int32_t(oneW); acc3 += int32_t(a3) * int32_t(oneW); } int8_t result0; int8_t result1; int8_t result2; int8_t result3; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); result2 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc2)), float(minValue)))); result3 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc3)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); result2 = static_cast<int8_t>(std::max(std::min(maxValue, acc2), minValue)); result3 = static_cast<int8_t>(std::max(std::min(maxValue, acc3), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; c[4 * 2] = result2; c[4 * 3] = result3; } ie += 4; a += 4; } if (eSize & 0x02) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // a = blockA + diff; a += diff; const int8_t* w = B; int8_t* blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; size_t ih = 0; for (; ih < (h & (~0x03)); ih += sparseBlockOC) { auto ihPack = ih >> 2; auto c = blockC + ihPack * cStride; int32_t initValue[4] = {0, 0, 0, 0}; if (nullptr != bias) { memcpy(initValue, bias + ih, 4 * sizeof(int32_t)); } int32_t acc0[4]; int32_t acc1[4]; memcpy(acc0, initValue, 4 * sizeof(int32_t)); memcpy(acc1, initValue, 4 * sizeof(int32_t)); const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t wv[4] = {*w++, *w++, *w++, *w++}; // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:, a value[0-1]:\n", ie, a - A, w - B - 1, c - C); // formatMatrix(wv, {4}); // formatMatrix(a, {2}); // MNN_PRINT("\n"); a = a + diff; for (int lane = 0; lane < 4; lane++) { acc0[lane] += int32_t(a0) * int32_t(wv[lane]); acc1[lane] += int32_t(a1) * int32_t(wv[lane]); } } int8_t result0[4]; int8_t result1[4]; if (scales) { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue)))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc1[lane])), float(minValue)))); } } else { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue))); result1[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc1[lane]), minValue))); } } memcpy(c , result0, 4 * sizeof(int8_t)); // store continuous c memcpy(c + 4 , result1, 4 * sizeof(int8_t)); } blockC += (ih >> 2) * cStride; for (; ih < h; ih++) { auto ihSubIndex = ih & 0x03; auto c = blockC + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; int32_t acc1 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t a1 = a[1]; const int8_t oneW = *w++; // MNN_PRINT("2-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:%d, a value[0-1]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {2}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); acc1 += int32_t(a1) * int32_t(oneW); } int8_t result0; int8_t result1; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); result1 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc1)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); result1 = static_cast<int8_t>(std::max(std::min(maxValue, acc1), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; c[4] = result1; } ie += 2; a += 2; } if (eSize & 0x01) { const int* dataOffset = dataOffsetMap; const int diff = *dataOffset++; // const float* a = blockA + diff; a += diff; const int8_t * w = B; int8_t * blockC = C + (ie << 2); const unsigned int* nnz = NNZMap; size_t ih = 0; for (; ih < (h & (~0x03)); ih += sparseBlockOC) { auto ihPack = ih >> 2; auto c = blockC + ihPack * cStride; int32_t initValue[4] = {0, 0, 0, 0}; if (nullptr != bias) { memcpy(initValue, bias + ih, 4 * sizeof(int32_t)); } int32_t acc0[4]; memcpy(acc0, initValue, 4 * sizeof(int32_t)); const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t wv[4] = {*w++, *w++, *w++, *w++}; // MNN_PRINT("16-loop: ie:%zu, a offset:%ld, w offset:%ld, c offset:%ld, w value:, a value[0-1]:\n", ie, a - A, w - B - 1, c - C); // formatMatrix(wv, {4}); // formatMatrix(a, {16}); // MNN_PRINT("\n"); a = a + diff; for (int lane = 0; lane < 4; lane++) { acc0[lane] += int32_t(a0) * int32_t(wv[lane]); } } int8_t result0[4]; if (scales) { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih + lane] * float(acc0[lane])), float(minValue)))); } } else { for (int lane = 0; lane < 4; lane++) { result0[lane] = static_cast<int8_t>(roundf(std::max(std::min(maxValue, acc0[lane]), minValue))); } } memcpy(c, result0, 4 * sizeof(int8_t)); // store continuous c } blockC += (ih >> 2) * cStride; for (; ih < h; ih++) { auto ihSubIndex = ih & 0x03; auto c = blockC + ihSubIndex; const int32_t initValue = nullptr != bias ? bias[ih] : 0; int32_t acc0 = initValue; const int lElement = *nnz++; for (auto il = 0; il < lElement; il++) { const int diff = *dataOffset++; const int8_t a0 = a[0]; const int8_t oneW = *w++; // MNN_PRINT("1-loop: ie:%zu, a offset:%ld, c offset:%ld, w offset:%ld, w value:%d, a value[0]:\n", ie, a - A, w - B - 1, c - C, oneW); // formatMatrix(a, {1}); // MNN_PRINT("\n"); a = a + diff; acc0 += int32_t(a0) * int32_t(oneW); } int8_t result0; if (scales) { result0 = static_cast<int8_t>(roundf(std::max(std::min(float(maxValue), scales[ih] * float(acc0)), float(minValue)))); } else { result0 = static_cast<int8_t>(std::max(std::min(maxValue, acc0), minValue)); } // how to store faster: st4 / transpose / c[0] = result0; } ie += 1; // a += 1; } } static int8_t MNNInt32ToInt8(int data, int bias, float scale, float maxValue, float minValue) { float value = (float)(data + bias) * scale; value = ALIMAX(value, minValue); value = ALIMIN(value, maxValue); return static_cast<int8_t>(roundf(value)); } static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) { const int bytes = ((post->useInt8 == 1) ? 1 : 4); float fp32min = 0, fp32max = 0; int weight_step_Z = src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) + 4 * 2 * GEMM_INT8_UNIT; int weight_step_Y = (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); if (0 == post->useInt8 && post->fp32minmax) { fp32min = (post->fp32minmax)[0]; fp32max = (post->fp32minmax)[1]; } float* biasPtr = (float*)post->biasFloat; auto accumbuff = post->accumBuffer; auto blockNum = post->blockNum; for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_z = dst + dz * dst_step; auto accum_z = accumbuff; for (int bk = 0; bk < blockNum; ++bk) { // block's weight&scale&bias const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y); const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT; const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; const auto srcSumPtr = post->srcKernelSum + bk * realCount; for (int w = 0; w < realCount; ++w) { const auto src_x = src + bk * src_depth_quad * GEMM_INT8_SRC_UNIT * realCount + w * GEMM_INT8_SRC_UNIT; auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; auto accum_x = accum_z + w * GEMM_INT8_UNIT; int32_t dstTemp[4] = {0, 0, 0, 0}; for (int sz = 0; sz < src_depth_quad; ++sz) { const auto weight_sz = weight_dz + weight_step_Y * sz; const auto src_z = src_x + sz * realCount * GEMM_INT8_SRC_UNIT; for (int j = 0; j < GEMM_INT8_UNIT; ++j) { const auto weight_j = weight_sz + j * GEMM_INT8_SRC_UNIT; for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) { dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i]; } } } for (int j = 0; j < GEMM_INT8_UNIT; ++j) { float value = dstTemp[j] * scale_dz[j] + srcSumPtr[w] * weightBias_dz[j]; if (post->inputScale) { value = dstTemp[j] * scale_dz[j] * (post->inputScale + bk * realCount)[w] + srcSumPtr[w] * weightBias_dz[j]; } if (post->inputBias) { auto weightKernelSum = post->weightKernelSum + dz * (blockNum * GEMM_INT8_UNIT) + bk * GEMM_INT8_UNIT; value += ((post->inputBias + bk * realCount)[w] * weightKernelSum[j]); } if (post->useInt8 == 0) { if (bk > 0) { float dstv = ((float*)accum_x)[j]; value += dstv; } if (bk == blockNum - 1) { if (biasPtr) { value += bias_dz[j]; } if (post->fp32minmax) { value = std::min(std::max(fp32min, value), fp32max); } ((float*)dst_x)[j] = value; } else { ((float*)accum_x)[j] = value; } } else { value += bias_dz[j]; value = ALIMAX(value, post->minValue); value = ALIMIN(value, post->maxValue); dst_x[j] = static_cast<int8_t>(roundf(value)); } } } } } } static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) { uint32_t c = 0xf; const int bytes = 4; float fp32min = 0, fp32max = 0; int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); int weight_step_Z = weight_step_Y * src_depth_quad + 4 * 2 * GEMM_INT8_UNIT; MNN_ASSERT(post->useInt8==0); if (post->fp32minmax) { fp32min = (post->fp32minmax)[0]; fp32max = (post->fp32minmax)[1]; } float* biasPtr = (float*)post->biasFloat; auto accumbuff = post->accumBuffer; auto blockNum = post->blockNum; for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_z = dst + dz * dst_step; auto accum_z = accumbuff; for (int bk = 0; bk < blockNum; ++bk) { const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y); const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT; const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; const auto srcSumPtr = post->srcKernelSum + bk * realCount; for (int w = 0; w < realCount; ++w) { const auto src_x = src + w * GEMM_INT8_SRC_UNIT; auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; auto accum_x = accum_z + w * GEMM_INT8_UNIT; int32_t dstTemp[4] = {0, 0, 0, 0}; for (int sz = 0; sz < src_depth_quad; ++sz) { const auto weight_sz = (uint8_t*)weight_dz + weight_step_Y * sz; const auto src_z = src_x + sz * realCount * GEMM_INT8_SRC_UNIT; int w8[64]; // 64=GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT for (int k = 0; k < 32; ++k) { w8[k] = (weight_sz[k]>>4); w8[k + 32] = (weight_sz[k] & c); } for (int j = 0; j < GEMM_INT8_UNIT; ++j) { const auto weight_j = w8 + j * GEMM_INT8_SRC_UNIT; for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) { dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i]; } } } for (int j = 0; j < GEMM_INT8_UNIT; ++j) { float value = dstTemp[j] * scale_dz[j] + srcSumPtr[w] * weightBias_dz[j]; if (post->inputScale) { value = dstTemp[j] * scale_dz[j] * (post->inputScale + bk * realCount)[w] + srcSumPtr[w] * weightBias_dz[j]; } if (post->inputBias) { auto weightKernelSum = post->weightKernelSum + dz * (blockNum * GEMM_INT8_UNIT) + bk * GEMM_INT8_UNIT; value += ((post->inputBias + bk * realCount)[w] * weightKernelSum[j]); } if (bk > 0) { float dstv = ((float*)accum_x)[j]; value += dstv; } if (bk == blockNum - 1) { if (biasPtr) { value += bias_dz[j]; } if (post->fp32minmax) { value = std::min(std::max(fp32min, value), fp32max); } ((float*)dst_x)[j] = value; } else { ((float*)accum_x)[j] = value; } } } } } } static void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack) { #ifdef MNN_USE_SSE float offset = 128.f; uint8_t* srcPtr = (uint8_t*)src; uint8_t* dstPtr = (uint8_t*)dst; #else float offset = 0.f; const int8_t* srcPtr = src; int8_t* dstPtr = dst; #endif float mulVal = 0.f; float inputZero = static_cast<float>(params->inputZeroPoint[0]) + offset; float outputZero = static_cast<float>(params->outputZeroPoint[0]) + offset; int32_t minval = params->minValue + offset; int32_t maxval = params->maxValue + offset; for (int j = 0;j < depthQuad; ++j) { const float* slopeZ = slope + pack * j; const auto srcZ = srcPtr + pack * j * planeNumber; auto dstZ = dstPtr + pack * j * planeNumber; for (int i = 0; i < planeNumber; ++i) { for (int c = 0; c < pack; ++c) { float valInput = (static_cast<float>(srcZ[pack * i + c]) - inputZero) * params->inputScale[0]; if (valInput < 0) { valInput *= slopeZ[c]; } auto mulVal = valInput * params->outputScale[0] + outputZero; dstZ[pack * i + c] = ALIMIN(ALIMAX(static_cast<int32_t>(roundf(mulVal)), minval), maxval); } } } } static void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) { return MNNGemmInt8AddBiasScale_16x4_Unit(dst, src, weight, src_depth_quad, dst_step, dst_depth_quad, post, realCount); } static void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder) { #ifdef MNN_USE_SSE int offset = 128; uint8_t* dstPtr = (uint8_t*)dst; const int16_t* srcPtr = (int16_t*)src; const int16_t* weightPtr = (int16_t*)weight; #else int offset = 0; int8_t* dstPtr = dst; const int8_t* srcPtr = src; const int8_t* weightPtr = weight; #endif int pack = 16; auto bias_z = parameters->bias; auto scale_z = parameters->scale; int dx, fx, fy; for (dx = 0; dx < width; ++dx) { auto dst_x = dstPtr + dx * pack; int32_t dstInt32[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; const auto src_z = srcPtr + src_w_step * dx; for (fy = 0; fy < fh; ++fy) { const auto src_y = src_z + fy * dilateY_step; const auto weight_y = weightPtr + fy * fw * pack; for (fx = 0; fx < fw; ++fx) { const auto src_x = src_y + fx * dilateX_step; const auto weight_x = weight_y + pack * fx; for (int j = 0; j < pack; ++j) { dstInt32[j] += static_cast<int32_t>(src_x[j]) * static_cast<int32_t>(weight_x[j]); } } } for (int i = 0; i < pack; ++i) { float val = (dstInt32[i] + bias_z[i]) * scale_z[i]; int valOut = roundf(val) + offset; if (valOut > parameters->maxValue + offset) { valOut = parameters->maxValue + offset; } if (valOut < parameters->minValue + offset) { valOut = parameters->minValue + offset; } dst_x[i] = static_cast<int>(valOut); } } } static void MNNLineDepthWiseInt8AddBiasScaleUnit3x3(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder) { MNNLineDepthWiseInt8AddBiasScaleUnit(dst, src, weight, parameters, width, src_w_step, fw, fh, dilateX_step, dilateY_step, idxOrder); } #endif #ifndef MNN_USE_NEON void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec) { // quanParamVec: // 00: scale is vector // 10: zero is vector // 11: both are vector float scale4[4] = {scalep[0], scalep[0], scalep[0], scalep[0] }; float zero4[4] = {zeroPoint[0], zeroPoint[0], zeroPoint[0], zeroPoint[0]}; if (quanParamVec % 2 == 1) { scale4[0] = scalep[0]; scale4[1] = scalep[1]; scale4[2] = scalep[2]; scale4[3] = scalep[3]; } if (quanParamVec >> 1 == 1) { zero4[0] = zeroPoint[0]; zero4[1] = zeroPoint[1]; zero4[2] = zeroPoint[2]; zero4[3] = zeroPoint[3]; } for (int i = 0; i < sizeQuad; ++i) { for (int j=0; j<4; ++j) { int v = (int)roundf(src[4*i+j] * scale4[j]) + zero4[j]; if (v > maxValue) { v = maxValue; } if (v < minValue) { v = minValue; } dst[4*i+j] = v; } } } void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, const float* zeroPoint, ssize_t quantParamVec) { float scale_[4] = {scale[0], scale[0], scale[0], scale[0]}; float zero_[4] = {zeroPoint[0], zeroPoint[0], zeroPoint[0], zeroPoint[0]}; if (quantParamVec & 1) { ::memcpy(scale_, scale, 4 * sizeof(float)); } if (quantParamVec >> 1) { ::memcpy(zero_, zeroPoint, 4 * sizeof(float)); } for (int i = 0; i < size; ++i) { const auto srcStart = src + i * 4; auto dstStart = dst + i * 4; for (int j = 0; j < 4; ++j) { dstStart[j] = static_cast<float>(srcStart[j] - zero_[j]) * scale_[j]; } } } void MNNAvgPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor) { int pack = 16; int8_t* dstPtr = dst; const int8_t* srcPtr = src; for (int ox = 0; ox < outputWidth; ++ox) { std::vector<int> sum_(pack, 0); for (int y = 0; y < kernely; ++y) { for (int x = 0; x < kernelx; ++x) { const int8_t *inputPtr = srcPtr + pack* (x + inputWidth* y); for (int idx = 0; idx < pack; ++idx) { sum_[idx] += *(inputPtr + idx); } } } for (int idx = 0; idx < pack; ++idx) { *(dstPtr + idx) = static_cast<int8_t>((sum_[idx] * factor)>>24); } dstPtr = dstPtr + pack; srcPtr = srcPtr + pack* stridesx; } } void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx) { int pack = 16; int8_t* dstPtr = dst; const int8_t* srcPtr = src; for (int ox = 0; ox < outputWidth; ++ox){ std::vector<int8_t> results(pack, INT8_MIN); for (int y = 0; y < kernely; ++y) { for (int x = 0; x < kernelx; ++x) { const int8_t* inputPtr = srcPtr + pack* (x + inputWidth* y); for (int idx = 0; idx < pack; ++idx) { results[idx] = std::max(results[idx], *(inputPtr + idx)); } } } for (int idx = 0; idx < pack;++idx) { *(dstPtr + idx) = results[idx]; } dstPtr = dstPtr + pack; srcPtr = srcPtr + pack* stridesx; } } void MNNBinaryAddInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { float sum = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = inputRaw0; const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { float inp0 = static_cast<int32_t>(inputData0[0] - offset - (int32_t)params->inputZeroPoint[0]) * static_cast<float>(inputScalesFp32[0]); float inp1 = static_cast<int32_t>(inputData1[i] - offset - (int32_t)params->inputZeroPoint[1]) * static_cast<float>(inputScalesFp32[1]); sum = inp0 + inp1; } else if (needBroadcast == 1) { float inp0 = static_cast<int32_t>(inputData0[i] - offset - (int32_t)params->inputZeroPoint[0]) * static_cast<float>(inputScalesFp32[0]); float inp1 = static_cast<int32_t>(inputData1[0] - offset - (int32_t)params->inputZeroPoint[1]) * static_cast<float>(inputScalesFp32[1]); sum = inp0 + inp1; } else { float inp0 = static_cast<int32_t>(inputData0[i] - offset - (int32_t)params->inputZeroPoint[0]) * static_cast<float>(inputScalesFp32[0]); float inp1 = static_cast<int32_t>(inputData1[i] - offset - (int32_t)params->inputZeroPoint[1]) * static_cast<float>(inputScalesFp32[1]); sum = inp0 + inp1; } int value = (int)roundf(sum * inputScalesFp32[2]) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } void MNNBinarySubInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { float res = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = inputRaw0; const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { float inp0 = static_cast<int32_t>(inputData0[0] - offset - (int32_t)params->inputZeroPoint[0]) * static_cast<float>(inputScalesFp32[0]); float inp1 = static_cast<int32_t>(inputData1[i] - offset - (int32_t)params->inputZeroPoint[1]) * static_cast<float>(inputScalesFp32[1]); res = inp0 - inp1; } else if (needBroadcast == 1) { float inp0 = static_cast<int32_t>(inputData0[i] - offset - (int32_t)params->inputZeroPoint[0]) * static_cast<float>(inputScalesFp32[0]); float inp1 = static_cast<int32_t>(inputData1[0] - offset - (int32_t)params->inputZeroPoint[1]) * static_cast<float>(inputScalesFp32[1]); res = inp0 - inp1; } else { float inp0 = static_cast<int32_t>(inputData0[i] - offset - (int32_t)params->inputZeroPoint[0]) * static_cast<float>(inputScalesFp32[0]); float inp1 = static_cast<int32_t>(inputData1[i] - offset - (int32_t)params->inputZeroPoint[1]) * static_cast<float>(inputScalesFp32[1]); res = inp0 - inp1; } int value = (int)roundf(res * inputScalesFp32[2]) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } void MNNBinaryMulInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { float res = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = inputRaw0; const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { float inp0 = (inputData0[0] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; float inp1 = (inputData1[i] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; res = inp0 * inp1; } else if (needBroadcast == 1) { float inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; float inp1 = (inputData1[0] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; res = inp0 * inp1; } else { float inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; float inp1 = (inputData1[i] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; res = inp0 * inp1; } int value = (int)roundf(res * inputScalesFp32[2]) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } void MNNBinaryMinInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { int res = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = inputRaw0; const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { int32_t inp0 = static_cast<int32_t>(inputData0[0] - offset - params->inputZeroPoint[0]) * static_cast<int32_t>(inputScalesInt32[0]); int32_t inp1 = static_cast<int32_t>(inputData1[i] - offset - params->inputZeroPoint[1]) * static_cast<int32_t>(inputScalesInt32[1]); res = std::min(inp0, inp1); } else if (needBroadcast == 1) { int32_t inp0 = static_cast<int32_t>(inputData0[i] - offset - params->inputZeroPoint[0]) * static_cast<int32_t>(inputScalesInt32[0]); int32_t inp1 = static_cast<int32_t>(inputData1[0] - offset - params->inputZeroPoint[1]) * static_cast<int32_t>(inputScalesInt32[1]); res = std::min(inp0, inp1); } else { int32_t inp0 = static_cast<int32_t>(inputData0[i] - offset - params->inputZeroPoint[0]) * static_cast<int32_t>(inputScalesInt32[0]); int32_t inp1 = static_cast<int32_t>(inputData1[i] - offset - params->inputZeroPoint[1]) * static_cast<int32_t>(inputScalesInt32[1]); res = std::min(inp0, inp1); } int value = roundf((res + (1<<15)) / (1 << 16)) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (res < 0) { value = roundf((res - (1<<15)) / (1 << 16)) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); } if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } void MNNBinaryMaxInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { int res = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = inputRaw0; const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { int32_t inp0 = static_cast<int32_t>(inputData0[0] - offset - params->inputZeroPoint[0]) * static_cast<int32_t>(inputScalesInt32[0]); int32_t inp1 = static_cast<int32_t>(inputData1[i] - offset - params->inputZeroPoint[1]) * static_cast<int32_t>(inputScalesInt32[1]); res = std::max(inp0, inp1); } else if (needBroadcast == 1) { int32_t inp0 = static_cast<int32_t>(inputData0[i] - offset - params->inputZeroPoint[0]) * static_cast<int32_t>(inputScalesInt32[0]); int32_t inp1 = static_cast<int32_t>(inputData1[0] - offset - params->inputZeroPoint[1]) * static_cast<int32_t>(inputScalesInt32[1]); res = std::max(inp0, inp1); } else { int32_t inp0 = static_cast<int32_t>(inputData0[i] - offset - params->inputZeroPoint[0]) * static_cast<int32_t>(inputScalesInt32[0]); int32_t inp1 = static_cast<int32_t>(inputData1[i] - offset - params->inputZeroPoint[1]) * static_cast<int32_t>(inputScalesInt32[1]); res = std::max(inp0, inp1); } int value = (res + (1<<15)) / (1 << 16) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (res < 0) { value = (res - (1<<15)) / (1 << 16) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); } if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } void MNNBinarySqdInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { float res = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = inputRaw0; const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { float inp0 = (inputData0[0] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; float inp1 = (inputData1[i] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; res = (inp0 - inp1) * (inp0 - inp1); } else if (needBroadcast == 1) { float inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; float inp1 = (inputData1[0] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; res = (inp0 - inp1) * (inp0 - inp1); } else { float inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; float inp1 = (inputData1[i] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; res = (inp0 - inp1) * (inp0 - inp1); } int value = (int)roundf(res * inputScalesFp32[2]) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } void MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const int32_t* alpha, int32_t mShiftBits, ssize_t minValue, ssize_t maxValue, int8_t* inputZeroPoint, int8_t* outputZeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) { #ifdef MNN_USE_SSE const uint8_t* srcPtr = (uint8_t*)src; uint8_t* dstPtr = (uint8_t*)dst; int offset = 128; #else const int8_t* srcPtr = src; int8_t* dstPtr = dst; int offset = 0; #endif int intputZeroPointValue = *inputZeroPoint + offset; int outputZeroPointValue = *outputZeroPoint + offset; int d = mShiftBits - 1; for (int z = 0; z < biasNumber; ++z) { auto dstZ = dstPtr + planeNumber * pack * z; const auto srcZ = srcPtr + planeNumber * pack * z; std::vector<int32_t> biasZ(pack), alphaZ(pack); for (int i = 0; i < pack; ++i) { biasZ[i] = *(bias + pack * z + i); alphaZ[i] = *(alpha + pack * z + i); } for (int p = 0; p < planeNumber; ++p) { auto dstX = dstZ + pack * p; const auto srcX = srcZ + pack * p; for (int i = 0; i < pack; ++i) { int32_t val = static_cast<int32_t>(srcX[i] - intputZeroPointValue) * alphaZ[i] + biasZ[i]; int valOut = roundf((val + (1<<d)) / (1 << mShiftBits)) + outputZeroPointValue; if (val < 0) { valOut = roundf((val - (1<<d)) / (1 << mShiftBits)) + outputZeroPointValue; } if (valOut > maxValue + offset) { valOut = maxValue + offset; } if (valOut < minValue + offset) { valOut = minValue + offset; } dstX[i] = valOut; } } } } #endif // #ifndef MNN_USE_NEON #ifndef MNN_USE_SSE void MNNInt8FunctionInit() { // do nothing } #endif // #ifndef MNN_USE_SSE template<int EP, int LP, int HP> static void _ArmBasicMNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) { int number = info[0]; int eReal = info[1]; int eOutsideStride = info[2] / sizeof(float); int eDest = EP; int offset = info[3]; const int LUNIT = LP / sizeof(float); int realDstCount = info[4]; for (int n=0; n<number; ++n) { int e = el[4 * n + 0]; // to fill int l = el[4 * n + 1]; int eOffset = el[4 * n + 2]; // have filled int lOffset = el[4 * n + 3]; int lC = lOffset / LP; int lR = lOffset % LP; int eC = eOffset / eDest; int eR = eOffset % eDest; int eS = eDest - eR; // printf("e=%d, eC=%d, lC=%d, eR=%d, lR=%d\n", e, eC, lC, eR, lR); bool lastBag = false; int eOutsideStride4LastBag = eOutsideStride; if (realDstCount % EP > 0) { int jobsE = realDstCount - eOffset - e; if (jobsE == 0 || (jobsE < (realDstCount % EP))) { lastBag = true; } } auto dest = (int32_t*)(destOrigin + lC * eDest * LP + lR + eC * info[2] + eR * LP); auto source = (int32_t*)sourceGroup[n]; int lRemain = l / 4; int lR4 = lR / 4; int lS = LUNIT - lR4; if (lastBag && e + eR < EP) { int elast = ALIMAX(eR + e, realDstCount % EP); dest = (int32_t*)(destOrigin + lC * elast * LP + lR + eC * info[2] + eR * LP); } // Step for start int offsetLC = lC * LUNIT + lR / 4; if (lR4 > 0) { int step = ALIMIN(lS, lRemain); for (int x=0; x<step; ++x) { int eRemain = e; auto d = dest + x; auto s = source + x * eReal; if (eR > 0) { int eStep = ALIMIN(eRemain, eS); for (int yi=0; yi<eStep; ++yi) { d[yi * LUNIT] = s[yi * offset]; } eRemain-=eStep; if (!lastBag ||eRemain >= EP) { d += (eOutsideStride - eR * LUNIT); } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - ((offsetLC / LUNIT) * EP * LUNIT); d += (eOutsideStride4LastBag - eR * LUNIT + (offsetLC / LUNIT) * eFill * LUNIT); } s += eS * offset; } while (eRemain > 0) { int eStep = ALIMIN(eDest, eRemain); for (int yi=0; yi<eStep; ++yi) { d[yi * LUNIT] = s[yi * offset]; } eRemain-=eStep; if (!lastBag || eRemain >= EP) { d+= eOutsideStride; } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - ((offsetLC / LUNIT) * EP * LUNIT); d+= (eOutsideStride4LastBag + (offsetLC / LUNIT) * eFill * LUNIT); } s+= eStep * offset; } offsetLC++; } lRemain -= step; if (lastBag && e + eR < EP) { int eFill = ALIMAX(realDstCount % EP, e + eR); int nextLP = (eFill * LP - lR) / sizeof(int32_t); dest += nextLP; } else { int nextLP = (eDest * LP - lR) / sizeof(int32_t); dest += nextLP; } source += eReal * step; } while (lRemain > 0) { int step = ALIMIN(lRemain, LUNIT); for (int x=0; x<step; ++x) { int eRemain = e; auto d = dest + x; auto s = source + x * eReal; if (eR > 0) { int eStep = ALIMIN(eRemain, eS); for (int yi=0; yi<eStep; ++yi) { d[yi * LUNIT] = s[yi * offset]; } eRemain-=eStep; if (!lastBag ||eRemain >= EP) { d += (eOutsideStride - eR * LUNIT); } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - ((offsetLC / LUNIT) * EP * LUNIT); d += (eOutsideStride4LastBag - eR * LUNIT + (offsetLC / LUNIT) * eFill * LUNIT); } s += eS * offset; } while (eRemain > 0) { int eStep = ALIMIN(eDest, eRemain); for (int yi=0; yi<eStep; ++yi) { d[yi * LUNIT] = s[yi * offset]; } eRemain-=eStep; if (!lastBag || eRemain >= EP) { d+= eOutsideStride; } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - ((offsetLC / LUNIT) * EP * LUNIT); d+= (eOutsideStride4LastBag + (offsetLC / LUNIT) * eFill * LUNIT); } s+= eStep * offset; } offsetLC++; } lRemain -= step; if (lastBag && e + eR < EP) { int efill = ALIMAX(e + eR, realDstCount % EP); dest += efill * LUNIT; } else { dest += eDest * LUNIT; } source += eReal * step; } } } static void MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { *UNIT = GEMM_INT8_UNIT; *SRC_UNIT = GEMM_INT8_SRC_UNIT; *DST_XUNIT = GEMM_INT8_DST_XUNIT; } static void MNNGetGemmUnitSdot(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { *UNIT = 8; *SRC_UNIT = 4; *DST_XUNIT = 12; } static void MNNGetGemmUnitI8mm(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { *UNIT = 8; *SRC_UNIT = 8; *DST_XUNIT = 10; } template<int EP, int HP> static void _ArmBasicMNNPackC4ForMatMul_A_L4(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) { int number = info[0]; int eReal = info[1]; int eDest = EP; int offset = info[3]; const int LP = 4; int eOutsideStride = info[2] / sizeof(float); int kernelCountUnit = eOutsideStride; int realDstCount = info[4]; for (int n=0; n<number; ++n) { int e = el[4 * n + 0]; int l = el[4 * n + 1]; int eOffset = el[4 * n + 2]; int lOffset = el[4 * n + 3]; int eC = eOffset / EP; int eR = eOffset % EP; int eS = eDest - eR; bool lastBag = false; int eOutsideStride4LastBag = eOutsideStride; int eres = realDstCount - eOffset; if (realDstCount % EP > 0) { int jobsE = realDstCount - eOffset - e; if (jobsE == 0 || (jobsE < (realDstCount % EP))) { lastBag = true; } } auto dest = (int32_t*)(destOrigin + lOffset * eDest + eC * info[2] + eR * LP); auto source = (int32_t*)sourceGroup[n]; int lRemain = l / sizeof(float); if (lastBag && e + eR < EP) { int elast = ALIMIN(ALIMAX(eR + e, realDstCount % EP), EP); dest = (int32_t*)(destOrigin + lOffset * elast + eC * info[2] + eR * LP); } int offsetLC = lOffset / 4; for (int x=0; x<lRemain; ++x) { int eRemain = e; auto d = dest; auto s = source; if (1 == offset) { if (eR > 0) { int eStep = ALIMIN(eRemain, eS); ::memcpy(d, s, eStep * sizeof(int32_t)); eRemain-=eStep; if (!lastBag ||eRemain >= EP) { d += (eOutsideStride - eR); } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float)); d += (eOutsideStride4LastBag - eR + offsetLC * eFill); } s += eS * offset; } while (eRemain > 0) { int eStep = ALIMIN(eDest, eRemain); ::memcpy(d, s, eStep * sizeof(int32_t)); eRemain-=eStep; if (!lastBag || eRemain >= EP) { d+= eOutsideStride; } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float)); d+= (eOutsideStride4LastBag + offsetLC * eFill); } s+= eStep * offset; } } else { if (eR > 0) { int eStep = ALIMIN(eRemain, eS); for (int yi=0; yi<eStep; ++yi) { d[yi] = s[yi * offset]; } eRemain-=eStep; if (!lastBag ||eRemain >= EP) { d += (eOutsideStride - eR); } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float)); d += (eOutsideStride4LastBag - eR + offsetLC * eFill); } s += eS * offset; } while (eRemain > 0) { int eStep = ALIMIN(eDest, eRemain); for (int yi=0; yi<eStep; ++yi) { d[yi] = s[yi * offset]; } eRemain-=eStep; if (!lastBag || eRemain >= EP) { d+= eOutsideStride; } else { int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding>0 eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float)); d+= (eOutsideStride4LastBag + offsetLC * eFill); } s+= eStep * offset; } } source += eReal; if (lastBag && e + eR < EP ) { // eR=0;eR>0 int efill = ALIMAX(e + eR, realDstCount % EP); dest += efill; } else { dest += eDest; } offsetLC++; } } } namespace MNN { static CoreInt8Functions* gCoreFunc = nullptr; void MNNCoreInt8FunctionInit() { /* CoreInt8Functions without sdot */ gCoreFunc = new CoreInt8Functions; // MatMul gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_16x4_Unit_FAST; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnit; #ifdef MNN_LOW_MEMORY gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_16x4_w4_Unit; #endif // Im2Col gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<GEMM_INT8_DST_XUNIT, GEMM_INT8_SRC_UNIT, GEMM_INT8_UNIT>; // conv depthwise gCoreFunc->ConvDepthwiseLineInt8 = MNNLineDepthWiseInt8AddBiasScaleUnit; gCoreFunc->MNNFloat2Int8 = MNNFloat2Int8; gCoreFunc->MNNInt8ScaleToFloat = MNNInt8ScaleToFloat; // sparse gCoreFunc->MNNGetSparseQuantMatMulPackMode = MNNGetSparseQuantMatMulPackMode; gCoreFunc->MNNPackForSparseQuantMatMul_B = MNNPackForSparseQuantMatMul_B; gCoreFunc->MNNPackedSparseQuantMatMulEpx1 = MNNPackedSparseQuantMatMulEpx1; gCoreFunc->MNNPackedSparseQuantMatMulEpx4 = MNNPackedSparseQuantMatMulEpx4; gCoreFunc->MNNPackC4Int8ForMatMul_ASparse = _MNNPackC4Int8ForMatMul_ASparse; // pooling gCoreFunc->MNNAvgPoolInt8 = MNNAvgPoolInt8; gCoreFunc->MNNMaxPoolInt8 = MNNMaxPoolInt8; // ReluWithSlopeChannel gCoreFunc->MNNReluWithSlopeChannelInt8 = MNNReluWithSlopeChannelInt8; #if defined(__aarch64__) auto core = MNNGetCoreFunctions(); if (core->supportSDot) { // MatMul gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV82_Unit; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSdot; // Im2Col gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4<12, 8>; // ConvDepthwise gCoreFunc->ConvDepthwise3x3LineInt8_ARM82 = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3; core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM82; #if defined(MNN_LOW_MEMORY) #ifdef MNN_USE_ARMV82 gCoreFunc->DynamicQuanInput_ARM82 = DynamicQuanInput_ARM82; gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16; gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16; gCoreFunc->DynamicQuanInputAndReorder_ARM82 = DynamicQuanInputAndReorder_ARM82; #endif gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV82_w4_Unit; #endif } if (core->supportI8mm) { // MatMul gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm; core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM86; #if defined(MNN_LOW_MEMORY) gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit; #ifdef MNN_USE_ARMV82 gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16; gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16; #endif #endif // Im2Col gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>; } #endif MNNInt8FunctionInit(); } CoreInt8Functions* MNNGetInt8CoreFunctions() { return gCoreFunc; } };