ErrorCode DenseConvolutionTiledImpl::onResize() — excerpt from
source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp, lines 423-718.


// Prepare the tiled dense convolution for execution.
// Derives the im2col / packed-GEMM tiling parameters from the backend's core
// functions, acquires (then immediately releases, in the backend's dynamic-
// buffer style) the scratch buffers, and builds mFunction = (threadNumber,
// lambda) that runs the im2col + packed-matmul pipeline at execute time.
// Two schedules are produced depending on mConvPerfconfig:
//   - isParallelInner: each plane tile is processed by all threads together
//     (im2col split over input-channel packs, GEMM split over output-channel tiles);
//   - otherwise: plane tiles are statically divided across threads.
// Returns OUT_OF_MEMORY if a scratch allocation fails, NO_ERROR otherwise.
ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    CPUConvolution::onResize(inputs, outputs);
    auto input   = inputs[0];
    auto weight  = inputs[1];   // weight tensor (second input)
    Tensor *bias = nullptr;
    if (inputs.size() > 2) {
        bias = inputs[2];       // bias is optional (third input)
    }
    auto core    = static_cast<CPUBackend *>(backend())->functions();
    int bytes    = core->bytes;          // element size of the compute type (e.g. 4 for fp32, 2 for fp16)
    float weightBytes  = bytes;          // may become fractional/smaller when weights are stored quantized (low-memory path)
    int unit     = core->pack;           // channel pack unit of this backend
    int matmulBytes = bytes;
    if (core->matmulBytes != 0) {
        // The GEMM kernel works on a different element size than storage
        // (e.g. widened accumulation); size the packed-A scratch accordingly.
        matmulBytes = core->matmulBytes;
    }
    auto packA   = core->MNNPackC4ForMatMul_A;
    int eP, lP, hP;   // GEMM tile sizes: eP = plane (E) tile, lP = reduction (L) pack, hP = output-channel (H) pack
    getPackParameter(&eP, &lP, &hP, core);
    auto matmulUnit   = core->MNNPackedMatMul;        // kernel for a full eP-wide tile
    auto matmulRemain = core->MNNPackedMatMulRemain;  // kernel for the tail tile (xC < eP)
    const uint8_t* dequantAlpha = nullptr;  // per-channel dequant scales (set only in the low-memory quantized path)
    const uint8_t* dequantBias = nullptr;   // per-channel dequant offsets (ditto)
    auto ic       = input->channel();
    auto icC4     = UP_DIV(ic, unit);
    auto L        = ic * mCommon->kernelY() * mCommon->kernelX();  // GEMM reduction length (ic * kernel area)
    auto tileC    = std::max(unit, hP);
    int blockSize = L;     // reduction length per quantization block (whole L when not block-quantized)
    int blockNum  = 1;     // number of quantization blocks along L
    float halfStride = 1;
    size_t weightStride = 0;
#ifdef MNN_LOW_MEMORY
    if (mResource && mResource->mDequantize.bits <= 8) {
        // Weights kept quantized (int8 only here): swap in the dequantizing
        // matmul kernels and derive the block-quant layout from the size of
        // the scale/bias buffer.
        MNN_ASSERT(mResource->mDequantize.bits == 8);
        DenseConvolutionTiledExecutor::selectLowMemoryMatmulFunc(&matmulUnit, &matmulRemain, &weightBytes, mResource->mDequantize.bits, core);
        int scaleSize = mResource->mDequantize.mScaleBias->size() / (2 * bytes); // buffer holds scale + bias pairs
        blockNum = scaleSize / (mResource->hU * mResource->hP);
        blockSize /= blockNum;
        dequantAlpha = mResource->mDequantize.mScaleBias->host<uint8_t>();
        dequantBias = dequantAlpha + scaleSize * bytes;  // offsets are stored after all scales
        weightStride = (L - blockSize) * hP;             // distance between consecutive blocks of one hP tile
    }
#endif
    auto kernel_width      = mCommon->kernelX();
    auto kernel_height     = mCommon->kernelY();
    auto output      = outputs[0];
    auto batch       = output->batch();
    int threadNumber = ((CPUBackend *)backend())->threadNumber();

    int  LRoundup = ROUND_UP(L, lP);
    int  LRoundupC4 = UP_DIV(LRoundup, unit);
    auto outputChannel = output->channel();
    auto oC4      = UP_DIV(outputChannel, tileC);
    auto ocUp4    = ROUND_UP(outputChannel, hP);
    auto kernelSize               = mCommon->kernelX() * mCommon->kernelY();

    ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParameters, mCommon, input, output, mPadX, mPadY, core, nullptr);
    // Per-thread scratch for the packed A (im2col-ed input) tile:
    // dim[0] = one slice per thread, dim[1] = one eP x LRoundup tile in matmulBytes.
    mTempBufferTranspose.buffer().type          = halide_type_of<uint8_t>();
    mTempBufferTranspose.buffer().dimensions    = 2;
    mTempBufferTranspose.buffer().dim[0].extent = threadNumber;
    mTempBufferTranspose.buffer().dim[1].extent = UP_DIV(L, lP) * lP * eP * matmulBytes;
    TensorUtils::setLinearLayout(&mTempBufferTranspose);
    auto plane    = mIm2ColParameters.ow * mIm2ColParameters.oh * batch;  // total output positions
    int tileCount = UP_DIV(plane, eP);                                    // number of eP-wide plane tiles
    mConvPerfconfig = bestTileConvolutionConfig(mCommon, input, output, threadNumber, backend());
    bool success = backend()->onAcquireBuffer(&mTempBufferTranspose, Backend::DYNAMIC);
    if (!success) {
        return OUT_OF_MEMORY;
    }

    // Per-thread blit-info scratch: kernelSize*maxLine source pointers followed
    // by kernelSize*maxLine (e,l,offset,?) int32 quadruples per thread.
    auto bufferAlloc   = static_cast<CPUBackend *>(backend())->getBufferAllocator();
    auto maxLine       = UP_DIV(eP, mIm2ColParameters.ow) + 1;  // max output rows touched by one eP tile
    auto tempPtr = bufferAlloc->alloc(kernelSize * maxLine * threadNumber * (4 * sizeof(int32_t) + sizeof(float *)));
    if (tempPtr.invalid()) {
        return OUT_OF_MEMORY;
    }
    // NOTE(review): both scratch buffers are released here although their
    // addresses are captured by the lambdas below — this relies on the MNN
    // dynamic allocator keeping the memory valid through onExecute; confirm.
    backend()->onReleaseBuffer(&mTempBufferTranspose, Backend::DYNAMIC);
    bufferAlloc->free(tempPtr);

    auto postParameters    = getPostParameters();
    mFunction.first        = threadNumber;

    if (mConvPerfconfig.isParallelInner) {
        // Inner-parallel schedule: the lambda is invoked once (argument is a
        // placeholder); all threads cooperate inside via MNN_CONCURRENCY,
        // so only thread-0's slice of each scratch buffer is used (the "* 0").
        auto rt = static_cast<const CPURuntime*>(backend()->getRuntime());
        std::vector<int> ocC4ParralSize(threadNumber + 1);  // per-thread split of output-channel tiles
        ocC4ParralSize[0] = 0;
        static_cast<CPUBackend *>(backend())->computeDivideSizes(oC4, ocC4ParralSize.data()+1);
        mFunction.second = [=](int placeholder) {
        const float* biasPtr = bias ? bias->host<float>() : nullptr;
        auto gemmBuffer = mTempBufferTranspose.host<uint8_t>() + mTempBufferTranspose.stride(0) * 0;
        auto srcPtr     = (float const **)(tempPtr.ptr() + 0 * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *)));
        auto el         = (int32_t *)(srcPtr + kernelSize * maxLine);  // int quadruples follow the pointer table
        auto weightPtr = weight->host<uint8_t>();

        // info[] feeds MNNPackC4ForMatMul_A: [0] blit count, [1] source plane
        // size, [2] eP, [3] x-stride.
        constexpr int InfoSize = 4;
        int32_t shapeInfo[InfoSize];
        int32_t* info = shapeInfo;
        info[1] = mIm2ColParameters.iw * mIm2ColParameters.ih * batch;
        info[2] = eP;
        info[3] = mIm2ColParameters.strideX;
        // parameters[] is the shape/stride record handed to the packed matmul
        // kernels (layout per the MNNPackedMatMul contract).
        size_t shapeParameters[PARAMETERSIZE];
        size_t* parameters = shapeParameters;
        parameters[0]          = eP * bytes;
        parameters[1]          = blockSize;
        parameters[2]          = outputChannel;   // overwritten per oc-tile below
        parameters[3]          = plane * unit * bytes;
        parameters[4]          = 0;
        parameters[5]          = weightStride; // Only used when block quant
        parameters[6]          = 0;            // current block index, set in the bk loop

        auto dstOrigin = output->host<uint8_t>();
        auto srcOrigin = input->host<uint8_t>();
        std::vector<int> im2colParallelSize(threadNumber + 1);
        im2colParallelSize[0] = 0;

        for (int x = 0; x < tileCount; x += 1) {
            int start  = (int)x * eP;
            int remain = plane - start;
            int xC     = remain > eP ? eP : remain;  // actual tile width (eP except for the last tile)
            auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo(srcPtr, el, start, xC, mIm2ColParameters, srcOrigin, bytes);
            int number    = res.first;   // number of blit segments produced
            bool needZero = res.second;  // true if some source positions fall in padding
            info[0] = number;
            if (needZero || lP != 1) {
                // Zero the scratch when padding leaves gaps or L is padded to lP.
                ::memset(gemmBuffer, 0, mTempBufferTranspose.stride(0));
            }
            // Blits are dispatched one at a time in the loop below, so force the
            // per-call blit count to 1 (the info[0] = number above is overwritten).
            info[0] = 1;
            int hw4Stride = info[1] * unit * bytes;  // bytes per channel-pack slice of the source
            // Split (blit segment x input-channel pack) pairs across threads.
            static_cast<CPUBackend *>(backend())->computeDivideSizes(number * icC4, im2colParallelSize.data() + 1);
            im2colParallelSize[0] = 0;
            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                int threadEL[4];
                int ticSta = im2colParallelSize[tId];
                int ticEnd = im2colParallelSize[tId+1];
                for(int tic_inumber = ticSta; tic_inumber < ticEnd; tic_inumber++) {
                        int inumber = tic_inumber / icC4;   // which blit segment
                        int t_ic = tic_inumber % icC4;      // which input-channel pack
                        memcpy(threadEL, el + 4 * inumber, 4 * sizeof(int));
                        threadEL[1] = std::min(ic - (t_ic * unit), unit);  // valid channels in this pack (tail may be partial)
                        const float* source = (const float*)((const uint8_t*)(srcPtr[inumber]) + t_ic * hw4Stride);
                        auto gemmDest = gemmBuffer + t_ic * unit * eP * bytes;
                        packA((float *)gemmDest, &source, info, threadEL);
                }
            }
            MNN_CONCURRENCY_END();

            if (xC == eP) {
                // Full tile: use the fixed-width kernel; threads split the oc tiles.
                MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                    size_t paraParameters[PARAMETERSIZE];
                    memcpy(paraParameters, parameters, PARAMETERSIZE * sizeof(size_t));  // private copy — parameters[2]/[6] are mutated per thread
                    for (int t_oc = ocC4ParralSize[tId]; t_oc < ocC4ParralSize[tId+1]; ++t_oc) {
                        int ocIndex = t_oc * tileC;
                        auto _dstFloatPtr = reinterpret_cast<float*>(dstOrigin + (ocIndex / unit * plane + start) * unit * bytes);
                        auto _weightFloatPtr = reinterpret_cast<const float*>(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes));
                        auto _biasFloatPtr = reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(biasPtr) + ocIndex * bytes);
                        paraParameters[2] = std::min(outputChannel - ocIndex, tileC);  // channels in this tile (tail may be partial)
                        // NOTE(review): dequantAlpha/dequantBias are nullptr unless the
                        // low-memory path ran; the offset arithmetic below then operates
                        // on nullptr — presumably the non-dequant kernels ignore these
                        // arguments, but confirm against the kernel contract.
                        auto k = reinterpret_cast<const uint8_t*>(dequantAlpha + ocIndex * bytes);
                        auto b = reinterpret_cast<const uint8_t*>(dequantBias + ocIndex * bytes);
                        const float* relufp32 = nullptr;
                        const float* exeBiasPtr = nullptr;
                        int finishedL = 0;
                        int wquantStride = 0;
                        auto _weightPtr = reinterpret_cast<const int8_t*>(_weightFloatPtr);
                        uint8_t*  _APtr      = reinterpret_cast<uint8_t*>(gemmBuffer);
                        // Accumulate over quantization blocks; post ops (relu/bias)
                        // are applied only on the final block.
                        for (int bk = 0; bk < blockNum; ++bk) {
                            paraParameters[6] = bk;
                            if (bk == blockNum - 1) {
                                relufp32 = postParameters.data();
                                exeBiasPtr = _biasFloatPtr;
                            }
                            finishedL = blockSize * bk;
                            wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);
                            matmulUnit(_dstFloatPtr, (float*)(_APtr + eP * finishedL * bytes), (float*)(_weightPtr + wquantStride), paraParameters, relufp32, exeBiasPtr, (float*)(k + bk * ocUp4 * bytes), (float*)(b + bk * ocUp4 * bytes));
                        }
                    }
                }
                MNN_CONCURRENCY_END();
            } else {
                // Tail tile (xC < eP): same split, but the remain kernel takes xC.
                MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                    size_t paraParameters[PARAMETERSIZE];
                    memcpy(paraParameters, parameters, PARAMETERSIZE * sizeof(size_t));
                    for (int t_oc = ocC4ParralSize[tId]; t_oc < ocC4ParralSize[tId+1]; ++t_oc) {
                        int ocIndex = t_oc * tileC;
                        auto _dstFloatPtr = reinterpret_cast<float*>(dstOrigin + (ocIndex / unit * plane + start) * unit * bytes);
                        auto _weightFloatPtr = reinterpret_cast<const float*>(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes));
                        auto _biasFloatPtr = reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(biasPtr) + ocIndex * bytes);
                        paraParameters[2] = std::min(outputChannel - ocIndex, tileC);
                        auto k = reinterpret_cast<const uint8_t*>(dequantAlpha + ocIndex * bytes);
                        auto b = reinterpret_cast<const uint8_t*>(dequantBias + ocIndex * bytes);
                        const float* relufp32 = nullptr;
                        const float* exeBiasPtr = nullptr;
                        int finishedL = 0;
                        int wquantStride = 0;
                        const int8_t* _weightPtr = reinterpret_cast<const int8_t*>(_weightFloatPtr);
                        uint8_t*  _APtr      = reinterpret_cast<uint8_t*>(gemmBuffer);
                        for (int bk = 0; bk < blockNum; ++bk) {
                            paraParameters[6] = bk;
                            if (bk == blockNum - 1) {
                                relufp32 = postParameters.data();
                                exeBiasPtr = _biasFloatPtr;
                            }
                            finishedL = blockSize * bk;
                            wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);
                            matmulRemain(_dstFloatPtr, (float*)(_APtr + eP * finishedL * bytes), (float*)(_weightPtr + wquantStride), xC, paraParameters, relufp32, exeBiasPtr, (float*)(k + bk * ocUp4 * bytes), (float*)(b + bk * ocUp4 * bytes));
                        }
                    }
                }
                MNN_CONCURRENCY_END();
            }

        }
    };

    } else {
        // Tile-parallel schedule: each thread gets a contiguous range of plane
        // tiles and its own scratch slices (indexed by tId below).
        std::vector<int> divides(threadNumber + 1);
        divides[0] = 0;

        static_cast<CPUBackend *>(backend())->computeDivideSizes(tileCount, divides.data() + 1);

        mFunction.second       = [=](int tId) {
            const float* biasPtr = bias ? bias->host<float>() : nullptr;
            auto gemmBuffer = mTempBufferTranspose.host<uint8_t>() + mTempBufferTranspose.stride(0) * tId;  // this thread's packed-A slice
            auto srcPtr     = (float const **)(tempPtr.ptr() + tId * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *)));
            auto el         = (int32_t *)(srcPtr + kernelSize * maxLine);
            auto weightPtr = weight->host<float>();
            // Same info[]/parameters[] records as the inner-parallel path.
            int32_t info[4];
            info[1] = mIm2ColParameters.iw * mIm2ColParameters.ih * batch;
            info[2] = eP;
            info[3] = mIm2ColParameters.strideX;
            size_t parameters[PARAMETERSIZE];
            parameters[0]          = eP * bytes;
            parameters[1]          = blockSize;
            parameters[2]          = outputChannel;
            parameters[3]          = plane * unit * bytes;
            parameters[4]          = 0;
            parameters[5]          = weightStride; // Only used when block quant
            parameters[6]          = 0;

            auto dstOrigin = output->host<uint8_t>();
            auto srcOrigin = input->host<uint8_t>();
            int tEnd = divides[tId+1];
            int tStart = divides[tId];
            for (int x = (int)tStart; x < tEnd; ++x) {
                int start  = (int)x * eP;
                int remain = plane - start;
                int xC     = remain > eP ? eP : remain;  // actual tile width
                auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo(srcPtr, el, start, xC, mIm2ColParameters, srcOrigin, bytes);
                auto number = res.first;
                bool needZero = res.second;
                info[0] = number;
                if (needZero || lP != 1) {
                    // Zero-fill when padding leaves gaps or L is padded to lP.
                    ::memset(gemmBuffer, 0, mTempBufferTranspose.stride(0));
                }

                if (number > 0) {
                    // Pack the whole tile in one call (all blits at once).
                    packA((float *)gemmBuffer, srcPtr, info, el);
                }

                int finishedL = 0;
                int wquantStride = 0;
                int8_t* _weightPtr = reinterpret_cast<int8_t*>(weightPtr);
                auto _dstFloatPtr = reinterpret_cast<float*>(dstOrigin + start * unit * bytes);
                const float* relufp32 = nullptr;
                const float* exeBiasPtr = nullptr;
                if (xC == eP) {
                    // matmulUnit(_dstFloatPtr, (float*)gemmBuffer, (float*)weightPtr, parameters, postParameters.data(), biasPtr, k, b);
                    // Accumulate over quantization blocks; post ops only on the last.
                    for (int bk = 0; bk < blockNum; ++bk) {
                        parameters[6] = bk;
                        if (bk == blockNum - 1) {
                            relufp32 = postParameters.data();
                            exeBiasPtr = biasPtr;
                        }
                        finishedL = blockSize * bk;
                        wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);

                        // NOTE(review): dequantAlpha/dequantBias are nullptr outside the
                        // low-memory path — presumably ignored by the non-dequant kernels.
                        matmulUnit(_dstFloatPtr, (float*)(gemmBuffer + bytes * eP * finishedL), (float*)(_weightPtr + wquantStride), parameters, relufp32, exeBiasPtr, (float*)(dequantAlpha + bk * ocUp4 * bytes), (float*)(dequantBias + bk * ocUp4 * bytes));
                    }
                } else {
                    for (int bk = 0; bk < blockNum; ++bk) {
                        parameters[6] = bk;
                        if (bk == blockNum - 1) {
                            relufp32 = postParameters.data();
                            exeBiasPtr = biasPtr;
                        }
                        finishedL = blockSize * bk;
                        wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);

                        matmulRemain(_dstFloatPtr, (float*)(gemmBuffer + eP * bytes * finishedL), (float*)(_weightPtr + wquantStride), xC, parameters, relufp32, exeBiasPtr, (float*)(dequantAlpha + bk * ocUp4 * bytes), (float*)(dequantBias + bk * ocUp4 * bytes ));
                    }
                    // matmulRemain(_dstFloatPtr, (float*)gemmBuffer, (float*)weightPtr, xC, parameters, postParameters.data(), biasPtr, k, b);
                }
            }
        };
    }
    return NO_ERROR;
}