void MNNPackedSparseMatMulEpx1()

in source/backend/cpu/compute/CommonOptFunction.cpp [1270:1593]


// Computes one tile of kCols consecutive output columns for all h output channels.
//
//   blockC        - first output element of the tile (channel 0, first column of the tile).
//   a             - current read position inside A; the updated position is returned.
//   B             - non-zero weights, stored back to back, channel after channel.
//   h             - number of output channels.
//   cStride       - distance (in floats) between two consecutive groups of 4 channels in C.
//   bias          - optional per-channel start value (nullptr means 0).
//   NNZMap        - NNZMap[ih] is the number of non-zero weights of channel ih.
//   dataOffsetMap - relative offsets (in floats) applied to `a`: one leading entry that is
//                   applied before the first weight, then one entry per non-zero weight that
//                   is applied after that weight has been consumed.
//
// Every column accumulates its products in the order the weights are stored, so the result
// does not depend on the tile width that happened to process the column.
template <int kCols>
static const float* MNNSparseEpx1ComputeTile(float* blockC, const float* a, const float* B, size_t h, size_t cStride,
                                             const float* bias, float minValue, float maxValue,
                                             const unsigned int* NNZMap, const int* dataOffsetMap) {
    const int* dataOffset = dataOffsetMap;
    a += *dataOffset++; // leading entry: move to the row of the first non-zero weight
    const float* w = B;
    const unsigned int* nnz = NNZMap;
    for (size_t ih = 0; ih < h; ++ih) {
        // Channel ih lives in channel group ih / 4, at lane ih % 4 of every 4-float column slot.
        float* c = blockC + (ih >> 2) * cStride + (ih & 0x03);
        const float initValue = nullptr != bias ? bias[ih] : 0.0f;
        float acc[kCols];
        for (int ic = 0; ic < kCols; ++ic) {
            acc[ic] = initValue;
        }

        const int lElement = *nnz++;
        for (int il = 0; il < lElement; ++il) {
            const int diff = *dataOffset++;
            const float oneW = *w++;
            for (int ic = 0; ic < kCols; ++ic) {
                acc[ic] += a[ic] * oneW;
            }
            a += diff; // jump to the row that belongs to the next non-zero weight
        }

        // Clamp to [minValue, maxValue] and scatter: neighbouring columns are 4 floats apart.
        for (int ic = 0; ic < kCols; ++ic) {
            c[4 * ic] = std::max(std::min(maxValue, acc[ic]), minValue);
        }
    }
    return a;
}

// Sparse matrix multiply C = clamp(bias + A * B) where every output channel has its own list
// of non-zero weights.
//
//   C              - output; column e of channel ih is written to
//                    C[(ih / 4) * cStride + e * 4 + (ih % 4)].
//   A              - packed left-hand matrix, walked through the offsets in dataOffsetMap.
//   B              - non-zero weights only, in channel order.
//   eSize          - number of output columns to compute.
//   parameter      - [0] eP * sizeof(float), [1] l, [2] h, [3] cStride in bytes.
//                    Entries [4] and [5] are not consumed by this kernel.
//   postParameters - optional; when non-null, [2] is the lower and [3] the upper clamp bound.
//   bias           - optional per-channel bias, indexed by output channel.
//   NNZMap         - number of non-zero weights for every output channel.
//   dataOffsetMap  - relative A offsets, re-read from its start for every tile.
//
// Columns are processed as full tiles first, then as 8/4/2/1 wide tiles selected by the low
// bits of eSize. The full-tile path always consumes 16 columns, so it presumes eP == 16
// - TODO confirm against the pack mode reported to callers.
// NOTE(review): the pointer advance after each tile assumes the offsets in dataOffsetMap sum
// up so that `a` is back at the base of the tile when the tile is finished - confirm against
// the code that builds the map.
void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) {

    const size_t eP = parameter[0] / sizeof(float);
    MNN_ASSERT((eP & 0x03) == 0); // In sparse calculate, eP should be evenly divided by 4
    const size_t l = parameter[1];
    const size_t h = parameter[2];
    const size_t cStride = parameter[3] / sizeof(float);
    const size_t aStride = eP * l;

    // Without post parameters the clamp is a no-op over the whole finite float range.
    float minValue = std::numeric_limits<float>::lowest();
    float maxValue = std::numeric_limits<float>::max();
    if (nullptr != postParameters) {
        minValue = postParameters[2];
        maxValue = postParameters[3];
    }

    const float* a = A;
    size_t ie = 0;

    // Full tiles. The bound checks that a whole tile still fits: testing only
    // `eP <= eSize` would start one extra full tile when eSize is larger than eP but not a
    // multiple of it, reading and writing past the requested columns.
    for (; ie < eSize && ie + eP <= eSize; ie += eP) {
        a = MNNSparseEpx1ComputeTile<16>(C + (ie << 2), a, B, h, cStride, bias, minValue, maxValue, NNZMap, dataOffsetMap);
        a += aStride;
    }

    // Remainder columns, widest tile first. Each step moves on by the columns it handled.
    if (eSize & 0x08) {
        a = MNNSparseEpx1ComputeTile<8>(C + (ie << 2), a, B, h, cStride, bias, minValue, maxValue, NNZMap, dataOffsetMap);
        ie += 8;
        a += 8;
    }
    if (eSize & 0x04) {
        a = MNNSparseEpx1ComputeTile<4>(C + (ie << 2), a, B, h, cStride, bias, minValue, maxValue, NNZMap, dataOffsetMap);
        ie += 4;
        a += 4;
    }
    if (eSize & 0x02) {
        a = MNNSparseEpx1ComputeTile<2>(C + (ie << 2), a, B, h, cStride, bias, minValue, maxValue, NNZMap, dataOffsetMap);
        ie += 2;
        a += 2;
    }
    if (eSize & 0x01) {
        // Last possible tile: the returned read position is not needed any more.
        MNNSparseEpx1ComputeTile<1>(C + (ie << 2), a, B, h, cStride, bias, minValue, maxValue, NNZMap, dataOffsetMap);
    }
}