in source/backend/cpu/compute/CommonOptFunction.cpp [1270:1593]
// Computes W consecutive output columns, for all h output rows, of the sparse
// product used by MNNPackedSparseMatMulEpx1. W is a compile-time tile width, so
// the per-lane loops have constant trip counts the compiler can unroll.
//   blockC        : output for the first column of this tile (C + column * 4)
//   a             : A pointer positioned at the start of this tile
//   B, NNZMap     : non-zero weights, consumed in order; NNZMap[ih] of them
//                   belong to output row ih
//   dataOffsetMap : entry 0 is added to `a` once before the first non-zero;
//                   each following entry is added after consuming one non-zero
//   bias          : optional per-row initial value (nullptr -> 0)
// Output row ih, tile column k is stored at
// blockC[(ih / 4) * cStride + 4 * k + ih % 4], clamped to [minValue, maxValue].
// Returns `a` after every offset has been applied; the caller advances it to
// the next tile.
template <int W>
static inline const float* MNNPackedSparseMatMulEpx1Tile(float* blockC, const float* a, const float* B,
                                                          const unsigned int* NNZMap, const int* dataOffsetMap,
                                                          const float* bias, size_t h, size_t cStride,
                                                          float minValue, float maxValue) {
    const int* dataOffset = dataOffsetMap;
    a += *dataOffset++;
    const float* w = B;
    const unsigned int* nnz = NNZMap;
    for (size_t ih = 0; ih < h; ++ih) {
        // Output rows are packed 4 at a time: row ih lives in pack ih/4, lane ih%4.
        float* c = blockC + (ih >> 2) * cStride + (ih & 0x03);
        const float initValue = nullptr != bias ? bias[ih] : 0.0f;
        float acc[W];
        for (int k = 0; k < W; ++k) {
            acc[k] = initValue;
        }
        const int lElement = static_cast<int>(*nnz++);
        for (int il = 0; il < lElement; ++il) {
            const float oneW = *w++;
            for (int k = 0; k < W; ++k) {
                acc[k] += a[k] * oneW;
            }
            // Jump to the A data paired with the next non-zero weight.
            a += *dataOffset++;
        }
        for (int k = 0; k < W; ++k) {
            c[4 * k] = std::max(std::min(maxValue, acc[k]), minValue);
        }
    }
    return a;
}

// Sparse (B) x packed dense (A) matrix multiply, generic C implementation that
// produces one output row at a time.
//
// For each output row ih in [0, h) and column e in [0, eSize):
//   C(ih, e) = clamp(bias[ih] + sum over the non-zeros of row ih of w * a[e])
// where `a` is positioned by walking dataOffsetMap, and the clamp range comes
// from postParameters.
//
// Parameters:
//   C              : output; C(ih, e) is stored at C[(ih / 4) * cStride + 4 * e + ih % 4]
//   A              : packed input; the pointer advances by aStride = eP * l floats per
//                    full eP-wide column tile
//   B              : non-zero weights, consumed sequentially row by row
//   eSize          : number of output columns to compute
//   parameter[0]   : eP in bytes (tile width = parameter[0] / sizeof(float))
//   parameter[1]   : l
//   parameter[2]   : h (number of output rows)
//   parameter[3]   : cStride in bytes
//   parameter[4..5]: not read by this implementation
//   postParameters : optional; [2] / [3] are the min / max clamp values
//   bias           : optional, one value per output row (nullptr -> 0)
//   NNZMap         : h entries, the non-zero count of each output row
//   dataOffsetMap  : initial A offset followed by one delta per non-zero
//
// NOTE(review): the full-tile path computes exactly 16 columns per eP step, and
// the tail is split into 8/4/2/1 passes, so this assumes eP == 16. Confirm
// against the sparse pack-mode setup.
void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) {
    const size_t eP = parameter[0] / sizeof(float);
    MNN_ASSERT((eP & 0x03) == 0); // In sparse calculate, eP should be evenly divided by 4
    const size_t l = parameter[1];
    const size_t h = parameter[2];
    const size_t cStride = parameter[3] / sizeof(float);
    const size_t aStride = eP * l;
    float minValue = -std::numeric_limits<float>::max();
    float maxValue = std::numeric_limits<float>::max();
    if (nullptr != postParameters) {
        minValue = postParameters[2];
        maxValue = postParameters[3];
    }

    const float* a = A;
    size_t ie = 0;
    // Full tiles: run only while a whole eP-wide tile remains. The previous
    // condition `ie < eSize && eP <= eSize` also ran a 16-wide pass over a
    // partial tail (e.g. eSize = 24), which the 8/4/2/1 branches then redid.
    // The eP > 0 guard prevents an endless loop for a degenerate eP.
    for (; eP > 0 && ie + eP <= eSize; ie += eP) {
        a = MNNPackedSparseMatMulEpx1Tile<16>(C + (ie << 2), a, B, NNZMap, dataOffsetMap, bias, h, cStride,
                                              minValue, maxValue);
        a += aStride;
    }

    // Tail narrower than one tile: split into 8/4/2/1-wide passes inside the
    // current A tile, each starting `width` floats further into it.
    // NOTE(review): this relies on the dataOffsetMap deltas walking `a` back to
    // the tile start after each pass; confirm against the offset encoder.
    const size_t remain = eSize - ie;
    if (remain & 0x08) {
        a = MNNPackedSparseMatMulEpx1Tile<8>(C + (ie << 2), a, B, NNZMap, dataOffsetMap, bias, h, cStride,
                                             minValue, maxValue);
        ie += 8;
        a += 8;
    }
    if (remain & 0x04) {
        a = MNNPackedSparseMatMulEpx1Tile<4>(C + (ie << 2), a, B, NNZMap, dataOffsetMap, bias, h, cStride,
                                             minValue, maxValue);
        ie += 4;
        a += 4;
    }
    if (remain & 0x02) {
        a = MNNPackedSparseMatMulEpx1Tile<2>(C + (ie << 2), a, B, NNZMap, dataOffsetMap, bias, h, cStride,
                                             minValue, maxValue);
        ie += 2;
        a += 2;
    }
    if (remain & 0x01) {
        MNNPackedSparseMatMulEpx1Tile<1>(C + (ie << 2), a, B, NNZMap, dataOffsetMap, bias, h, cStride,
                                         minValue, maxValue);
    }
}