in source/backend/cpu/compute/Int8FunctionsOpt.cpp [176:579]
// Computes kTile consecutive output columns (positions along e) for all h output channels of the
// 1-wide sparse int8 product, then requantizes and stores them.
//   blockC        : output base for this group of columns; element (e, ih) is written to
//                   blockC + 4 * e + (ih >> 2) * cStride + (ih & 3).
//   a             : input pointer for this group; first advanced by dataOffsetMap[0], then by one
//                   diff per consumed non-zero weight (after that weight's column values are read).
//   B / NNZMap    : non-zero weights consumed sequentially; NNZMap[ih] = non-zero count of channel ih.
//   bias / scales : optional per-channel bias (int32) and requantization scale (float).
// Returns `a` advanced exactly as described above so the caller can continue pointer arithmetic
// from the same position the original hand-unrolled code ended at.
template <int kTile>
static const int8_t* MNNPackedSparseQuantEpx1Tile(int8_t* blockC, const int8_t* a, const int8_t* B, size_t h,
                                                  size_t cStride, const int32_t* bias, const float* scales,
                                                  int32_t maxValue, int32_t minValue,
                                                  const unsigned int* NNZMap, const int* dataOffsetMap) {
    const int* dataOffset = dataOffsetMap;
    a += *dataOffset++; // first entry positions `a` at the first non-zero input row
    const int8_t* w = B;
    const unsigned int* nnz = NNZMap;
    for (size_t ih = 0; ih < h; ih++) {
        // Output is packed by 4 channels: channel group (ih >> 2), lane (ih & 3), column stride 4.
        int8_t* c = blockC + (ih >> 2) * cStride + (ih & 0x03);
        const int32_t initValue = nullptr != bias ? bias[ih] : 0;
        int32_t acc[kTile];
        for (int k = 0; k < kTile; k++) {
            acc[k] = initValue;
        }
        const int lElement = *nnz++;
        for (int il = 0; il < lElement; il++) {
            const int diff = *dataOffset++;
            const int32_t oneW = *w++;
            for (int k = 0; k < kTile; k++) {
                acc[k] += int32_t(a[k]) * oneW;
            }
            a += diff; // jump to the input row of the next non-zero weight
        }
        if (scales) {
            const float scale = scales[ih];
            for (int k = 0; k < kTile; k++) {
                // Clamp in float, then round; min/max are integral so rounding stays in range.
                c[4 * k] = static_cast<int8_t>(
                    roundf(std::max(std::min(float(maxValue), scale * float(acc[k])), float(minValue))));
            }
        } else {
            for (int k = 0; k < kTile; k++) {
                c[4 * k] = static_cast<int8_t>(std::max(std::min(maxValue, acc[k]), minValue));
            }
        }
    }
    return a;
}

// Scalar reference sparse int8 matmul with a 1-wide (per output channel) sparsity pattern.
//
// sparseQuantParam layout as read here:
//   [0] eSize   : number of output columns (e positions) to compute
//   [1] eP      : column-tile width of the packed input A
//   [2] aStride : distance in A between consecutive full eP tiles
//   [3] l       : (not used by this kernel)
//   [4] h       : number of output channels
//   [5] cStride : distance in C between consecutive groups of 4 output channels
// post supplies optional per-channel bias/scale and the [minValue, maxValue] clamp range.
// NNZMap / dataOffsetMap / B are the compressed sparse weight encoding (presumably produced by the
// sparse int8 weight packer -- TODO confirm); they are re-walked from the start for every tile.
//
// NOTE(review): full tiles are computed 16 columns wide, so this path assumes eP == 16 (the value
// expected on non-NEON builds -- confirm with MNNGetSparseQuantMatMulPackMode). The remainder
// (< 16 columns) is processed from the last, partially filled tile in 8/4/2/1-wide pieces.
void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap) {
    const size_t eSize = sparseQuantParam[0];
    const size_t eP = sparseQuantParam[1];
    const size_t aStride = sparseQuantParam[2];
    const size_t h = sparseQuantParam[4];
    const size_t cStride = sparseQuantParam[5];
    const int32_t* bias = post->bias;
    const float* scales = post->scale;
    const int32_t maxValue = post->maxValue;
    const int32_t minValue = post->minValue;

    const int8_t* a = A;
    size_t ie = 0;
    // Only run a full tile when all of its columns exist; the previous `eP <= eSize` guard let the
    // loop start a tile at ie with fewer than eP columns left and read/write out of bounds.
    for (; eP > 0 && ie + eP <= eSize; ie += eP) {
        a = MNNPackedSparseQuantEpx1Tile<16>(C + (ie << 2), a, B, h, cStride, bias, scales, maxValue, minValue,
                                             NNZMap, dataOffsetMap);
        a += aStride;
    }

    // Remaining columns share one packed tile; each piece starts right after the previous one.
    const size_t remain = eSize - ie;
    if (remain & 0x08) {
        a = MNNPackedSparseQuantEpx1Tile<8>(C + (ie << 2), a, B, h, cStride, bias, scales, maxValue, minValue,
                                            NNZMap, dataOffsetMap);
        ie += 8;
        a += 8;
    }
    if (remain & 0x04) {
        a = MNNPackedSparseQuantEpx1Tile<4>(C + (ie << 2), a, B, h, cStride, bias, scales, maxValue, minValue,
                                            NNZMap, dataOffsetMap);
        ie += 4;
        a += 4;
    }
    if (remain & 0x02) {
        a = MNNPackedSparseQuantEpx1Tile<2>(C + (ie << 2), a, B, h, cStride, bias, scales, maxValue, minValue,
                                            NNZMap, dataOffsetMap);
        ie += 2;
        a += 2;
    }
    if (remain & 0x01) {
        MNNPackedSparseQuantEpx1Tile<1>(C + (ie << 2), a, B, h, cStride, bias, scales, maxValue, minValue,
                                        NNZMap, dataOffsetMap);
    }
}