in source/backend/cpu/compute/Int8FunctionsOpt.cpp [581:1404]
// Converts one int32 accumulator into the int8 value that is stored in C.
//  - scales != nullptr: acc is multiplied by scales[oc], clamped to [minValue, maxValue] in float,
//    rounded with roundf and narrowed to int8_t.
//  - scales == nullptr: acc is clamped to [minValue, maxValue] as an integer and narrowed to int8_t.
static inline int8_t MNNSparseQuantRequantize(int32_t acc, const float* scales, size_t oc, int32_t minValue,
                                              int32_t maxValue) {
    if (scales) {
        return static_cast<int8_t>(
            roundf(std::max(std::min(float(maxValue), scales[oc] * float(acc)), float(minValue))));
    }
    return static_cast<int8_t>(std::max(std::min(maxValue, acc), minValue));
}

// Computes one tile of E consecutive output columns for all h output channels.
//
//  blockC         first output byte of the tile; column e / channel ih is written at
//                 blockC[(ih >> 2) * cStride + 4 * e + (ih & 3)].
//  a              current read position in the activation buffer. It is first moved by
//                 dataOffsetMap[0]; afterwards every consumed non-zero weight moves it by the next
//                 dataOffsetMap entry. E consecutive int8 values are read at each position.
//  B              sparse weights, consumed sequentially: 4 values per non-zero entry while whole
//                 groups of 4 output channels remain, then 1 value per entry for the leftover channels.
//  NNZMap         number of non-zero entries: one count per 4-channel group, followed by one count
//                 per leftover channel.
//  bias / scales  optional (may be nullptr), indexed by output channel.
//
// Returns the activation read position after the last offset was applied, so the caller can
// continue from it exactly like the unrolled code did.
template <int E>
static const int8_t* MNNSparseQuantTileEx4(int8_t* blockC, const int8_t* a, const int8_t* B, size_t h,
                                           size_t cStride, const int32_t* bias, const float* scales,
                                           int32_t minValue, int32_t maxValue, const unsigned int* NNZMap,
                                           const int* dataOffsetMap) {
    const int sparseBlockOC = 4;
    const int* dataOffset   = dataOffsetMap;
    a += *dataOffset++; // leading offset positions `a` on the first non-zero row
    const int8_t* w         = B;
    const unsigned int* nnz = NNZMap;
    const size_t hBlocked   = h & (~static_cast<size_t>(0x03));

    size_t ih = 0;
    // Output channels handled four at a time.
    for (; ih < hBlocked; ih += sparseBlockOC) {
        int8_t* c = blockC + (ih >> 2) * cStride;
        int32_t acc[E][4];
        for (int e = 0; e < E; ++e) {
            for (int lane = 0; lane < 4; ++lane) {
                acc[e][lane] = (nullptr != bias) ? bias[ih + lane] : 0;
            }
        }
        const int lElement = *nnz++;
        for (int il = 0; il < lElement; ++il) {
            const int diff     = *dataOffset++;
            const int8_t wv[4] = {w[0], w[1], w[2], w[3]};
            w += 4;
            for (int e = 0; e < E; ++e) {
                const int32_t av = static_cast<int32_t>(a[e]);
                for (int lane = 0; lane < 4; ++lane) {
                    acc[e][lane] += av * static_cast<int32_t>(wv[lane]);
                }
            }
            a += diff; // jump to the row of the next non-zero weight
        }
        for (int e = 0; e < E; ++e) {
            int8_t packed[4];
            for (int lane = 0; lane < 4; ++lane) {
                packed[lane] = MNNSparseQuantRequantize(acc[e][lane], scales, ih + lane, minValue, maxValue);
            }
            memcpy(c + 4 * e, packed, 4 * sizeof(int8_t)); // the 4 channels of one column are contiguous
        }
    }

    // Leftover (h % 4) channels, one at a time; they share the last 4-channel pack of C.
    int8_t* tailC = blockC + (h >> 2) * cStride;
    for (; ih < h; ++ih) {
        int8_t* c = tailC + (ih & 0x03);
        int32_t acc[E];
        const int32_t initValue = (nullptr != bias) ? bias[ih] : 0;
        for (int e = 0; e < E; ++e) {
            acc[e] = initValue;
        }
        const int lElement = *nnz++;
        for (int il = 0; il < lElement; ++il) {
            const int diff     = *dataOffset++;
            const int32_t oneW = static_cast<int32_t>(*w++);
            for (int e = 0; e < E; ++e) {
                acc[e] += static_cast<int32_t>(a[e]) * oneW;
            }
            a += diff;
        }
        for (int e = 0; e < E; ++e) {
            c[4 * e] = MNNSparseQuantRequantize(acc[e], scales, ih, minValue, maxValue);
        }
    }
    return a;
}

// Sparse int8 matrix multiplication with 4-channel packed output (reference C++ implementation).
//
//  C                 int8 output; column ie starts at C + ie * 4, groups of 4 output channels are
//                    cStride bytes apart.
//  A                 int8 activations, addressed through dataOffsetMap (see MNNSparseQuantTileEx4).
//  B                 int8 non-zero weights in consumption order.
//  sparseQuantParam  [0] eSize   number of output columns to compute
//                    [1] eP      column count of a full tile (step of the full-tile loop)
//                    [2] aStride amount `a` is advanced after each full tile
//                    [3] l       not used by this implementation
//                    [4] h       number of output channels
//                    [5] cStride byte distance between 4-channel groups in C
//  post              bias / scale (both optional) and the clamp range [minValue, maxValue].
//  NNZMap            non-zero counts per 4-channel group, then per leftover channel.
//  dataOffsetMap     activation pointer increments; re-read from the start for every tile.
//
// Full tiles are always computed 16 columns wide and the remainder is split by the low four bits
// of eSize (8/4/2/1).
// NOTE(review): this presumably requires eP == 16 - confirm against the callers.
void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {
    const size_t eSize   = sparseQuantParam[0];
    const size_t eP      = sparseQuantParam[1];
    const size_t aStride = sparseQuantParam[2];
    const size_t h       = sparseQuantParam[4];
    const size_t cStride = sparseQuantParam[5];

    const int32_t* bias    = post->bias;
    const float* scales    = post->scale;
    const int32_t maxValue = post->maxValue;
    const int32_t minValue = post->minValue;

    const int8_t* a = A;
    size_t ie       = 0;

    // Full tiles. Only run while a whole tile still fits: the previous condition
    // (ie < eSize && eP <= eSize) processed one tile too many when eSize > eP and eSize was not a
    // multiple of eP. The eP > 0 guard prevents a non-advancing loop.
    for (; eP > 0 && ie + eP <= eSize; ie += eP) {
        a = MNNSparseQuantTileEx4<16>(C + (ie << 2), a, B, h, cStride, bias, scales, minValue, maxValue, NNZMap,
                                      dataOffsetMap);
        a += aStride;
    }
    // Remainder columns, widest piece first.
    if (eSize & 0x08) {
        a = MNNSparseQuantTileEx4<8>(C + (ie << 2), a, B, h, cStride, bias, scales, minValue, maxValue, NNZMap,
                                     dataOffsetMap);
        ie += 8;
        a += 8;
    }
    if (eSize & 0x04) {
        a = MNNSparseQuantTileEx4<4>(C + (ie << 2), a, B, h, cStride, bias, scales, minValue, maxValue, NNZMap,
                                     dataOffsetMap);
        ie += 4;
        a += 4;
    }
    if (eSize & 0x02) {
        a = MNNSparseQuantTileEx4<2>(C + (ie << 2), a, B, h, cStride, bias, scales, minValue, maxValue, NNZMap,
                                     dataOffsetMap);
        ie += 2;
        a += 2;
    }
    if (eSize & 0x01) {
        // Last piece: the returned activation position is no longer needed.
        MNNSparseQuantTileEx4<1>(C + (ie << 2), a, B, h, cStride, bias, scales, minValue, maxValue, NNZMap,
                                 dataOffsetMap);
    }
}