ErrorCode ConvolutionPackWinograd::onResize()

in source/backend/cpu/compute/ConvolutionPackWinograd.cpp [216:560]


ErrorCode ConvolutionPackWinograd::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    CPUConvolution::onResize(inputs, outputs);
    int threadNumber = ((CPUBackend*)(backend()))->threadNumber();
    mTempBuffer->setLength(0, threadNumber);
    mGemmMidBuffer->setLength(0, threadNumber);
    mTransformMidBuffer->setLength(0, threadNumber);
    // FUNC_PRINT(mA->length(1));
    bool success = backend()->onAcquireBuffer(mTempBuffer.get(), Backend::DYNAMIC);
    success      = success && backend()->onAcquireBuffer(mGemmMidBuffer.get(), Backend::DYNAMIC);
    success      = success && (backend()->onAcquireBuffer(mTransformMidBuffer.get(), Backend::DYNAMIC));
    backend()->onReleaseBuffer(mTempBuffer.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mTransformMidBuffer.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mGemmMidBuffer.get(), Backend::DYNAMIC);
    if (!success) {
        return OUT_OF_MEMORY;
    }
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int pack = core->pack, bytes = core->bytes;

    auto input   = inputs[0];
    auto output  = outputs[0];
    auto dstUnit = mA->length(1); // m
    auto srcUnit = mA->length(0); // n
    int ePack, lPack, hPack;
    core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);

    auto srcUnit2 = srcUnit * srcUnit;
    auto alphaXStride = srcUnit * ePack * pack;
    auto IC4alpha2Stride = srcUnit2 * ePack * pack;

    int ow   = output->width();
    int oh   = output->height();
    int iw   = input->width();
    int ih   = input->height();
    int ic_4 = UP_DIV(input->channel(), pack);
    int dc_4 = UP_DIV(output->channel(), pack);
    int batch = input->batch();
    // MNN_PRINT("%d, %d\n", srcUnit, dstUnit);

    int padY = mPadY;
    int padX = mPadX;

    auto wUnit = UP_DIV(ow, dstUnit); // ow / m
    auto hUnit = UP_DIV(oh, dstUnit); // oh / m

    auto totalCount   = wUnit * hUnit * batch;
    // MNN_PRINT("ow=%d, oh=%d\n", ow, oh);
    
    std::vector<int> divides(threadNumber+1);
    static_cast<CPUBackend *>(backend())->computeDivideSizes(totalCount, divides.data()+1);
    divides[0] = 0;
    auto midBuffer0Bytes = srcUnit2 * pack * bytes;
    bool allow_x86_bf16_winograd = true;
#ifdef MNN_USE_SSE
    allow_x86_bf16_winograd = bytes != 2; // only bf16 has length of 2 byte on x86. fp16 dosnot exist.
#endif

    auto weight    = mResource->mWeight->host<uint8_t>();
    auto bias      = mResource->mBias->host<uint8_t>();
    mMainFunction.first = threadNumber;
    mMainFunction.second = [=](int tId, const uint8_t* inputOrigin, uint8_t* dstOrigin) {
        int tSta = divides[tId];
        int tFin = divides[tId+1];
        if (tSta >= tFin) {
            return;
        }
        int eRemain = (tFin-tSta) % ePack;
        std::vector<size_t> parameters(6);
        parameters[1] = input->channel();
        parameters[2] = output->channel();
        parameters[4] = 0;
        parameters[5] = 0;
        parameters[0] = eRemain * bytes;
        parameters[3] = ePack * pack * bytes;

        std::vector<size_t> parametersRemain = parameters;
        parametersRemain[0] = eRemain * bytes;
        parametersRemain[3] = eRemain * pack * bytes;

        auto srcOrigin = inputOrigin;
        auto _srcOrigin = mTempBuffer->host<uint8_t>() + tId * mTempBuffer->stride(0);
        auto gemmBuffer = (mGemmMidBuffer->host<uint8_t>() + tId * mGemmMidBuffer->stride(0));
        auto midBuffer0 = mTransformMidBuffer->host<uint8_t>() + tId * mTransformMidBuffer->stride(0);
        auto midBufferStride1 = mTransformMidBuffer->stride(1);
        auto weightStride = mResource->mWeight->stride(0);
        auto midBuffer1 = midBuffer0 + midBuffer0Bytes;
        for (int xIndex = tSta; xIndex < tFin; xIndex+=ePack) {
            int xReamin = tFin - xIndex;
            int xC      = xReamin > ePack ? ePack : xReamin;

            const bool fuseTransformPack = (xC * FULSE_THRESHHOLD_DENOMINATOR >= FULSE_THRESHHOLD_NUMERATOR * ePack) && allow_x86_bf16_winograd && nullptr != mSourceTransformPack && core->matmulBytes == 0;
            /*Source Transform Begin*/
#ifndef MNN_WINO_TRANFORM_TEST_CLOSE
            {
                int sourceZStep = iw * ih * batch * pack;
                int oyBegin = xIndex / wUnit;
                int oxBegin = xIndex % wUnit;
                int oyEnd = (xIndex + xC-1) / wUnit;
                int remain = xC;
                int destSOffset = 0;
                if (fuseTransformPack) {
                    for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
                        int hIndex = hbIndex % hUnit;
                        int bIndex = hbIndex / hUnit;
                        int step = ALIMIN(wUnit - oxBegin, remain);
                        int srcY  = hIndex * dstUnit - padY;
                        int ey    = ALIMIN(srcY + srcUnit, ih) - srcY;
                        int sy    = ALIMAX(0, srcY) - srcY;
                        auto srcStartY = srcOrigin + (srcY * iw + bIndex * iw * ih) * pack * bytes;

                        for (int si=0; si<step; ++si) {
                            auto wIndex = si + oxBegin;
                            int srcX  = wIndex * dstUnit - padX;
                            int sx    = ALIMAX(0, srcX) - srcX;
                            int ex    = ALIMIN(srcX + srcUnit, iw) - srcX;
                            int count = pack * (ex - sx);

                            auto srcStart = srcStartY + srcX * pack * bytes;
                            auto midBuffer1Offset = midBuffer1 + destSOffset;

                            if (ex - sx == srcUnit && ey - sy == srcUnit) {
                                for (int z = 0; z < ic_4; ++z) {
                                    auto srcZ = srcStart + z * sourceZStep * bytes;
                                    // Transform
                                    mSourceUnrollTransform((const float*)srcZ, (float*)midBuffer1Offset, iw * pack, ePack * pack, pack, alphaXStride);
                                    midBuffer1Offset += IC4alpha2Stride * bytes;
                                }
                            } else {
                                for (int z = 0; z < ic_4; ++z) {
                                    // Extract
                                    auto srcZ = srcStart + z * sourceZStep * bytes;
                                    ::memset(midBuffer0, 0, midBuffer0Bytes);
                                    if (count > 0) {
                                        for (int yy = sy; yy < ey; ++yy) {
                                            auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
                                            auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
                                            ::memcpy(dst_yy, src_yy, count * bytes);
                                        }
                                    }

                                    mSourceUnrollTransform((const float*)midBuffer0, (float*)midBuffer1Offset, srcUnit * pack, ePack * pack, pack, alphaXStride);
                                    midBuffer1Offset += IC4alpha2Stride * bytes;
                                }
                            }
                            destSOffset += pack * bytes;
                        }
                        oxBegin = 0;
                        remain -= step;
                    }
                } else {
                    int dstZStep    = xC * pack;  // hUnit*wUnit * 4
                    int unitStep    = ic_4 * xC * pack; //  C/4 * hUnit*wUnit * 4
                    for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
                        int hIndex = hbIndex % hUnit;
                        int bIndex = hbIndex / hUnit;
                        int step = ALIMIN(wUnit - oxBegin, remain);
                        int srcY  = hIndex * dstUnit - padY;
                        int ey    = ALIMIN(srcY + srcUnit, ih) - srcY; //h dim pack element length
                        int sy    = ALIMAX(0, srcY) - srcY;  // first y element
                        auto srcStartY = srcOrigin + (srcY * iw + bIndex * iw * ih) * pack * bytes;
                        for (int si=0; si<step; ++si) {
                            auto wIndex = si + oxBegin;
                            int srcX  = wIndex * dstUnit - padX;
                            int sx    = ALIMAX(0, srcX) - srcX;
                            int ex    = ALIMIN(srcX + srcUnit, iw) - srcX;
                            int count = pack * (ex - sx);

                            auto srcStart = srcStartY + srcX * pack * bytes;
                            auto dst_x = _srcOrigin + destSOffset;
                            if (ex - sx == srcUnit && ey - sy == srcUnit) {
                                for (int z = 0; z < ic_4; ++z) {
                                    auto srcZ = srcStart + z * sourceZStep * bytes;
                                    // Transform

                                    auto dstZ = dst_x + z * dstZStep * bytes;
                                    mSourceUnrollTransform((const float*)srcZ, (float*)midBuffer1, iw * pack, pack, pack, pack * srcUnit);
                                    mSourceUnrollTransform((const float*)midBuffer1, (float*)dstZ, srcUnit * pack, unitStep, pack, unitStep * srcUnit);
                                }
                            } else {
                                for (int z = 0; z < ic_4; ++z) {
                                    // Extract
                                    auto srcZ = srcStart + z * sourceZStep * bytes;
                                    ::memset(midBuffer0, 0, midBufferStride1);
                                    if (count > 0) {
                                        for (int yy = sy; yy < ey; ++yy) {
                                            auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
                                            auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
                                            ::memcpy(dst_yy, src_yy, count * bytes);
                                        }
                                    }

                                    auto dstZ = dst_x + z * dstZStep * bytes;

                                    mSourceUnrollTransform((const float*)midBuffer0, (float*)midBuffer1, srcUnit * pack, pack, pack, pack * srcUnit);
                                    mSourceUnrollTransform((const float*)midBuffer1, (float*)dstZ, srcUnit * pack, unitStep, pack, unitStep * srcUnit);

                                }
                            }
                            destSOffset += pack * bytes;
                        }
                        oxBegin = 0;
                        remain -= step;
                    }
                }
            }

#endif
            auto* _dstOrigin = _srcOrigin;
            if (fuseTransformPack) {
                _dstOrigin += ePack * srcUnit2 * ic_4 * pack * bytes;
                if (xC != ePack) {
                    auto midTransformPtr = midBuffer1 + xC * pack * bytes;
                    for (int i = 0; i < ic_4 * srcUnit2; ++i) {
                        memset(midTransformPtr, 0, (ePack - xC) * pack * bytes);
                        midTransformPtr += ePack * pack * bytes;
                    }
                }
                for (int iNw = 0; iNw < srcUnit; ++iNw) { // i_Nw
                    auto midTransformPtr = midBuffer1 + iNw * alphaXStride * bytes;
                    auto unitsGemmbuffer = gemmBuffer;
                    for (int z = 0; z < ic_4; ++z) { // ic_4
                        mSourceTransformPack((float*)midTransformPtr, (float*)unitsGemmbuffer, ePack * pack * ic_4);
                        unitsGemmbuffer += ePack * pack * bytes;
                        midTransformPtr += IC4alpha2Stride * bytes;
                    }
                    // Previous tranform requires xC aligned with EPack, xC should be Epack;
                    for (int iNh = 0; iNh < srcUnit; ++iNh) { // i_Nh, gemm
                        auto unitsGemmbuffer = gemmBuffer + iNh * ic_4 * pack * ePack * bytes;
                        auto _dstFloatPtr = (float*)(_dstOrigin + (iNh * srcUnit + iNw) * dc_4 * pack * ePack * bytes);
                        auto _weightFloatPtr = (const float*)(weight + (iNh * srcUnit + iNw) * weightStride);
                        core->MNNPackedMatMul(_dstFloatPtr, (float*)unitsGemmbuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr, nullptr, nullptr);
                    }
                }
            } else {
                /*Source Transform End*/
                // // Multi
                _dstOrigin += xC * srcUnit2 * ic_4 * pack * bytes;

                int32_t info[4];
                info[0] = 1;
                info[1] = xC;
                info[2] = xC;
                info[3] = 1;
                int32_t el[4];
                el[0] = xC;
                el[1] = parameters[1];
                el[2] = 0;
                el[3] = 0;
                if (xC == ePack) {
                    for (int i = 0; i < srcUnit2; ++i) {
                        auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
                        auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
                        auto _weightFloatPtr = (const float*)(weight + i * weightStride);
                        core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el);

                        core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr, nullptr, nullptr);
                    }
                } else {
                    for (int i = 0; i < srcUnit2; ++i) {
                        auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
                        auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
                        auto _weightFloatPtr = (const float*)(weight + i * weightStride);
                        core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el);
                        core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr, nullptr, nullptr);
                    }
                }
            }

#ifndef MNN_WINO_TRANFORM_TEST_CLOSE
            /* Dest Transform And Post Treat Begin */
            {
                auto DestUnrollTransform = mDestUnrollTransform.get();

                int srcZStep = (fuseTransformPack ? ePack : xC) * pack;
                int unitStep = (fuseTransformPack ? ePack : xC) * dc_4 * pack;
                int dstZStep = ow * oh * pack * batch;
                int oyBegin = xIndex / wUnit;
                int oxBegin = xIndex % wUnit;
                int oyEnd = (xIndex + xC-1) / wUnit;
                int remain = xC;
                auto dstS = _dstOrigin;
                for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
                    int hIndex = hbIndex % hUnit;
                    int bIndex = hbIndex / hUnit;
                    int step = std::min(wUnit - oxBegin, remain);
                    int dstY = hIndex * dstUnit;
                    int ey = ALIMIN(dstY + dstUnit, oh) - dstY;
                    for (int si=0; si<step; ++si) {
                        auto wIndex = si + oxBegin;
                        auto srcXi = dstS + pack * si * bytes;
                        int dstX = wIndex * dstUnit;
                        auto dstStart = dstOrigin + (dstX + dstY * ow + bIndex * ow * oh) * pack * bytes;
                        int ex = ALIMIN(dstX + dstUnit, ow) - dstX;

                        int count = ex * pack;
                        if (ex == dstUnit) {
                            for (int z = 0; z < dc_4; ++z) {
                                auto dstZAddr = dstStart + z * dstZStep * bytes;
                                auto srcZ     = srcXi + z * srcZStep * bytes;

                                DestUnrollTransform[srcUnit]((const float*)srcZ, (float*)midBuffer0, nullptr, nullptr, unitStep, dstUnit * pack, srcUnit * unitStep, pack);
                                DestUnrollTransform[ey]((const float*)midBuffer0, (float*)dstZAddr,  nullptr, nullptr, pack, pack * ow, pack * dstUnit, pack);
                            }
                        } else {
                            for (int z = 0; z < dc_4; ++z) {
                                auto dstZAddr = dstStart + z * dstZStep * bytes;
                                auto srcZ     = srcXi + z * srcZStep * bytes;

                                DestUnrollTransform[srcUnit]((const float*)srcZ, (float*)midBuffer0, nullptr, nullptr, unitStep, dstUnit * pack, srcUnit * unitStep, pack);
                                DestUnrollTransform[ey]((const float*)midBuffer0, (float*)midBuffer1,  nullptr, nullptr, pack, pack * dstUnit, pack * dstUnit, pack);

                                for (int yy = 0; yy < ey; ++yy) {
                                    auto dstYAddr = dstZAddr + yy * pack * ow * bytes;
                                    auto srcYAddr = midBuffer1 + yy * pack * dstUnit * bytes;
                                    ::memcpy(dstYAddr, srcYAddr, count * bytes);
                                }
                            }
                        }
                    }
                    oxBegin = 0;
                    remain -= step;
                    dstS += pack * step * bytes;
                }
            }
#endif
            /*Dest Transform And Post Treat End*/
        }
    };
    std::vector<int> postDivides(threadNumber+1);
    static_cast<CPUBackend *>(backend())->computeDivideSizes(dc_4, postDivides.data()+1);
    postDivides[0] = 0;

    mPostFunction.first = threadNumber;
    mPostFunction.second = [=](int tId, uint8_t* outputOrigin) {
        auto dstOrigin = outputOrigin;
        int tSta = postDivides[tId];
        int tFin = postDivides[tId+1];
        for (int dy=tSta; dy < tFin; ++dy) {
            auto dataFloatPtr = (float*)(dstOrigin + ow * oh * batch * dy * pack * bytes);
            auto biasFloatPtr = (const float*)(bias + pack * dy * bytes);
            core->MNNAxByClampBroadcastUnit(dataFloatPtr, dataFloatPtr, biasFloatPtr, ow * oh * batch, 0, 0, 1,  mPostParameters.data());
        }
    };
    return NO_ERROR;
}