ErrorCode ConvolutionPackFreeWinograd::onExecute()

in source/backend/cpu/compute/ConvolutionPackFreeWinograd.cpp [77:622]


/*
 * Pack-free Winograd convolution: computes outputs[0] from inputs[0].
 *
 * The output plane is split into dstUnit x dstUnit blocks; the `totalCount` blocks (over all
 * batches) are processed in tiles of up to eTile blocks. For each tile:
 *   1. source transform - every srcUnit x srcUnit input window (zero-filled where it overlaps the
 *      padding) is run through mSourceUnrollTransform twice and written into the tile buffer;
 *   2. multiply - for each of the srcUnit^2 transformed points, packed GEMM kernels combine the
 *      tile with that point's slice of mResource->mWeight, writing results directly after the
 *      source-transform region of the same tile buffer;
 *   3. dest transform - DestUnrollTransform maps the GEMM results back onto the output blocks
 *      (bias and mPostParameters are passed to its second pass), clipped at the output border.
 *
 * Scheduling: with mConvPerfconfig.isParallelInner, tiles run one after another and all threads
 * split each phase over a single shared tile buffer; otherwise every thread processes whole
 * tiles using its own slice (stride(0)) of mTempBuffer and mTransformMidBuffer.
 */
ErrorCode ConvolutionPackFreeWinograd::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    const int pack        = core->pack;
    const int bytes       = core->bytes;
    const int pack_stride = pack * bytes; // bytes of one channel-packed element

    auto input    = inputs[0];
    auto output   = outputs[0];
    auto dstUnit  = mA->length(1);
    auto srcUnit  = mA->length(0);
    auto srcUnit2 = srcUnit * srcUnit;

    // Only lPack is used here: input channels are rounded up to a multiple of it for the GEMM.
    int ePackMax, lPack, hPack;
    core->MNNGetMatMulPackMode(&ePackMax, &lPack, &hPack);
    const int ePack        = mConvPerfconfig.ePack;
    const int eTile        = mConvPerfconfig.eTile;
    const int hPackDynamic = mConvPerfconfig.hPack;

    const int ow         = output->width();
    const int oh         = output->height();
    const int iw         = input->width();
    const int ih         = input->height();
    const int oc         = output->channel();
    const int ic         = input->channel();
    const int ic_roundup = ROUND_UP(ic, lPack);
    const int ic_4       = UP_DIV(ic, pack);
    const int dc_4       = UP_DIV(oc, pack);
    const int batch      = input->batch();
    const int padY       = mPadY;
    const int padX       = mPadX;

    auto wUnit      = UP_DIV(ow, dstUnit);
    auto hUnit      = UP_DIV(oh, dstUnit);
    auto totalCount = wUnit * hUnit * batch; // dstUnit x dstUnit output blocks over all batches
    const int threadNumber = std::max(((CPUBackend *)backend())->threadNumber(), 1);
    const int tileCount    = UP_DIV(totalCount, eTile);

    // GEMM parameter template; entries [3] and [6] are filled in per tile by makeTileParameters.
    const std::vector<size_t> Tile2MatMulParameters = {
        static_cast<size_t>(ePack * ic_4 * pack * bytes),
        static_cast<size_t>(ic),
        0,
        0,
        static_cast<size_t>(ic_roundup * mConvPerfconfig.hPack * bytes),
        static_cast<size_t>(mConvPerfconfig.hPack * bytes),
        0};

    auto srcOrigin = input->host<uint8_t>();
    auto dstOrigin = output->host<uint8_t>();
    auto weight    = mResource->mWeight->host<uint8_t>();
    auto bias      = mResource->mBias->host<uint8_t>();
    const int weightStride      = mResource->mWeight->stride(0); // per-point offset into `weight`
    const float* postParameters = mPostParameters.data();
    auto DestUnrollTransform    = mDestUnrollTransform.get();
    // A thread's transform scratch is two consecutive windows: scratch0, then scratch1 at +midBuffer0Bytes.
    auto midBuffer0Bytes = srcUnit2 * pack * bytes;

    // Number of output blocks in the tile that starts at block `xIndex` (the last tile may be short).
    auto tileWidth = [&](int xIndex) {
        return ALIMIN(eTile, totalCount - xIndex);
    };

    // Fills the GEMM parameters for a tile of `eTileReal` blocks: `blockParams` covers the full
    // ePack-wide groups, `tailParams` the trailing partial group (eTileReal % ePack columns).
    auto makeTileParameters = [&](int eTileReal, std::vector<size_t>& blockParams, std::vector<size_t>& tailParams) {
        const int tLast = eTileReal % ePack;
        blockParams     = Tile2MatMulParameters;
        blockParams[3]  = eTileReal * pack_stride;
        blockParams[6]  = eTileReal - tLast;
        tailParams      = blockParams;
        tailParams[6]   = tLast;
    };

    // Source transform of one (input channel block `z`, tile column `eIndex`) pair of the tile that
    // starts at block `xIndex`. Results go to `tileBuffer`, where each ePack group of columns is laid
    // out as [ic_4][ePackSegment][pack] and consecutive transformed points are ICUnitStep floats apart.
    auto sourceTransformUnit = [&](int z, int eIndex, int xIndex, int eTileReal, uint8_t* tileBuffer,
                                   uint8_t* scratch0, uint8_t* scratch1) {
        const int bTransStride = wUnit * hUnit;
        const int ib_stride    = iw * ih;
        const int ICUnitStep   = ic_4 * eTileReal * pack;
        const int sourceZStep  = ib_stride * batch * pack_stride;

        const int tile_k  = xIndex + eIndex;
        const int bIndex  = tile_k / bTransStride;
        const int hwIndex = tile_k % bTransStride;
        const int hIndex  = hwIndex / wUnit;
        const int wIndex  = hwIndex % wUnit;
        const int iEpack  = eIndex % ePack;
        const int iETile  = eIndex - iEpack;
        // The last ePack group of a tile may be narrower; its channel stride shrinks to match.
        const int ePackSegment = ALIMIN(ePack, eTileReal - iETile);

        const int ihIndex = hIndex * dstUnit - padY;
        const int iwIndex = wIndex * dstUnit - padX;
        const int ey      = ALIMIN(ihIndex + srcUnit, ih) - ihIndex;
        const int sy      = ALIMAX(0, ihIndex) - ihIndex;
        const int ex      = ALIMIN(iwIndex + srcUnit, iw) - iwIndex;
        const int sx      = ALIMAX(0, iwIndex) - iwIndex;
        const int count   = pack_stride * (ex - sx);

        auto srcZ = srcOrigin + (iwIndex + ihIndex * iw + bIndex * ib_stride) * pack_stride + z * sourceZStep;
        auto dstZ = tileBuffer + (iETile * ic_4 + z * ePackSegment + iEpack) * pack_stride;
        if (ex - sx == srcUnit && ey - sy == srcUnit) {
            // Window lies fully inside the input: first pass reads the tensor directly.
            mSourceUnrollTransform((const float*)srcZ, (float*)scratch1, iw * pack, pack, pack, pack * srcUnit);
        } else {
            // Window overlaps padding: copy the valid rows into a zeroed scratch window first.
            ::memset(scratch0, 0, mTransformMidBuffer->stride(1));
            if (count > 0) {
                for (int yy = sy; yy < ey; ++yy) {
                    auto dst_yy = scratch0 + (yy * srcUnit + sx) * pack_stride;
                    auto src_yy = srcZ + (iw * yy + sx) * pack_stride;
                    ::memcpy(dst_yy, src_yy, count);
                }
            }
            mSourceUnrollTransform((const float*)scratch0, (float*)scratch1, srcUnit * pack, pack, pack, pack * srcUnit);
        }
        // Second pass scatters the window into the tile buffer, one transformed point per ICUnitStep.
        mSourceUnrollTransform((const float*)scratch1, (float*)dstZ, srcUnit * pack, ICUnitStep, pack, ICUnitStep * srcUnit);
    };

    // Calls the packed GEMM kernel that handles `ocBlocks` consecutive pack-wide output-channel
    // blocks for `eSize` (>= 1) tile columns. The kernel tables are indexed only here, when a call
    // is actually made, so a tile without a partial ePack group never reads index -1.
    auto packedMatMul = [core](int ocBlocks, int eSize, uint8_t* dst, const uint8_t* src,
                               const float* weightPtr, size_t* params) {
        switch (ocBlocks) {
            case 1:
                core->MNNPackedMatMulOC16Functions[eSize - 1]((float*)dst, (const float*)src, weightPtr, params, nullptr, nullptr);
                break;
            case 2:
                core->MNNPackedMatMulOC32Functions[eSize - 1]((float*)dst, (const float*)src, weightPtr, params, nullptr, nullptr);
                break;
            case 3:
                core->MNNPackedMatMulOC48Functions[eSize - 1]((float*)dst, (const float*)src, weightPtr, params, nullptr, nullptr);
                break;
            default:
                // No kernel for wider groups; presumably mConvPerfconfig.hPack <= 3 * pack (TODO confirm).
                break;
        }
    };

    // GEMM for transformed point `i` and output-channel group `ocGroup` (hPackDynamic channels) of a
    // tile of `eTileReal` columns: full ePack groups first, then the partial tail group if any.
    auto multiplyUnit = [&](int i, int ocGroup, int eTileReal, uint8_t* tileBuffer,
                            size_t* blockParams, size_t* tailParams) {
        const int tLast         = eTileReal % ePack;
        const int tBlock        = eTileReal - tLast;
        const int oc_pack_coeff = hPackDynamic / pack;
        const int t_oc          = ocGroup * oc_pack_coeff;
        const int ocValidPack   = ALIMIN(t_oc + oc_pack_coeff, dc_4) - t_oc;

        const uint8_t* src = tileBuffer + i * ic_4 * eTileReal * pack_stride;
        // GEMM results start right after the srcUnit2 source-transform points of this tile.
        uint8_t* dst = tileBuffer + eTileReal * srcUnit2 * ic_4 * pack_stride
                     + (i * dc_4 + t_oc) * eTileReal * pack_stride;
        auto weightPtr = (const float*)(weight + i * weightStride + (t_oc * ic_roundup * pack) * bytes);

        if (tBlock > 0) {
            packedMatMul(ocValidPack, ePack, dst, src, weightPtr, blockParams);
            src += tBlock * ic_4 * pack_stride;
            dst += tBlock * pack_stride;
        }
        if (tLast > 0) {
            packedMatMul(ocValidPack, tLast, dst, src, weightPtr, tailParams);
        }
    };

    // Dest transform of one (output channel block `z`, tile column `eIndex`) pair: reads the GEMM
    // results from `tileBuffer` and writes the corresponding output block, clipped to ow x oh.
    auto destTransformUnit = [&](int z, int eIndex, int xIndex, int eTileReal, const uint8_t* tileBuffer,
                                 uint8_t* scratch0, uint8_t* scratch1) {
        const int transb_stride = wUnit * hUnit;
        const int ob_stride     = ow * oh;
        const int srcTransZStep = eTileReal * pack_stride;
        const int OCUnitStep    = eTileReal * pack * dc_4;
        const int dstZStep      = ob_stride * batch * pack_stride;

        const int tile_k  = xIndex + eIndex;
        const int bIndex  = tile_k / transb_stride;
        const int hwIndex = tile_k % transb_stride;
        const int ohIndex = (hwIndex / wUnit) * dstUnit;
        const int owIndex = (hwIndex % wUnit) * dstUnit;
        const int ey      = ALIMIN(ohIndex + dstUnit, oh) - ohIndex;
        const int ex      = ALIMIN(owIndex + dstUnit, ow) - owIndex;

        auto gemmOut = tileBuffer + eTileReal * srcUnit2 * ic_4 * pack_stride;
        auto srcZ    = gemmOut + eIndex * pack_stride + z * srcTransZStep;
        auto dstZ    = dstOrigin + (owIndex + ohIndex * ow + bIndex * ob_stride) * pack_stride + z * dstZStep;
        auto biasPtr = (const float*)(bias + z * pack_stride);

        DestUnrollTransform[srcUnit]((const float*)srcZ, (float*)scratch0, nullptr, nullptr, OCUnitStep, dstUnit * pack, srcUnit * OCUnitStep, pack);
        if (ex == dstUnit) {
            // Full-width block: second pass writes straight into the output rows.
            DestUnrollTransform[ey]((const float*)scratch0, (float*)dstZ, biasPtr, postParameters, pack, pack * ow, pack * dstUnit, pack);
        } else {
            // Block crosses the right border: transform into scratch, then copy only the valid width.
            DestUnrollTransform[ey]((const float*)scratch0, (float*)scratch1, biasPtr, postParameters, pack, pack * dstUnit, pack * dstUnit, pack);
            const int count = ex * pack_stride;
            for (int yy = 0; yy < ey; ++yy) {
                ::memcpy(dstZ + yy * ow * pack_stride, scratch1 + yy * dstUnit * pack_stride, count);
            }
        }
    };

    const int oc_hpack = UP_DIV(oc, hPackDynamic); // output-channel groups per transformed point

    // ---- Parallel-inner scheduling: one shared tile buffer, threads split each phase. ----
    auto sharedTile            = mTempBuffer->host<uint8_t>();
    auto sharedMid0            = mTransformMidBuffer->host<uint8_t>();
    auto sharedMid1            = sharedMid0 + midBuffer0Bytes;
    const int midBufferOffset  = mTransformMidBuffer->stride(0);

    auto innerSourcePhase = [&](int tId, int xIndex, int eTileReal) {
        auto scratch0 = sharedMid0 + tId * midBufferOffset;
        auto scratch1 = sharedMid1 + tId * midBufferOffset;
        for (int work = tId; work < ic_4 * eTileReal; work += threadNumber) {
            sourceTransformUnit(work / eTileReal, work % eTileReal, xIndex, eTileReal, sharedTile, scratch0, scratch1);
        }
    };

    auto innerMultiplyPhase = [&](int tId, int eTileReal) {
        // Each thread works on its own parameter copies.
        std::vector<size_t> blockParams;
        std::vector<size_t> tailParams;
        makeTileParameters(eTileReal, blockParams, tailParams);
        for (int work = tId; work < srcUnit2 * oc_hpack; work += threadNumber) {
            multiplyUnit(work / oc_hpack, work % oc_hpack, eTileReal, sharedTile, blockParams.data(), tailParams.data());
        }
    };

    auto innerDestPhase = [&](int tId, int xIndex, int eTileReal) {
        auto scratch0 = sharedMid0 + tId * midBufferOffset;
        auto scratch1 = sharedMid1 + tId * midBufferOffset;
        for (int work = tId; work < dc_4 * eTileReal; work += threadNumber) {
            destTransformUnit(work / eTileReal, work % eTileReal, xIndex, eTileReal, sharedTile, scratch0, scratch1);
        }
    };

    // ---- Parallel-outer scheduling: each thread runs all three phases on its own tiles. ----
    auto outerTilePhase = [&](int tId) {
        auto tileBuffer = mTempBuffer->host<uint8_t>() + tId * mTempBuffer->stride(0);
        auto scratch0   = mTransformMidBuffer->host<uint8_t>() + tId * mTransformMidBuffer->stride(0);
        auto scratch1   = scratch0 + midBuffer0Bytes;
        std::vector<size_t> blockParams;
        std::vector<size_t> tailParams;

        for (int tIndex = tId; tIndex < tileCount; tIndex += threadNumber) {
            const int xIndex    = tIndex * eTile;
            const int eTileReal = tileWidth(xIndex);

            for (int z = 0; z < ic_4; ++z) {
                for (int e = 0; e < eTileReal; ++e) {
                    sourceTransformUnit(z, e, xIndex, eTileReal, tileBuffer, scratch0, scratch1);
                }
            }

            makeTileParameters(eTileReal, blockParams, tailParams);
            for (int i = 0; i < srcUnit2; ++i) {
                for (int group = 0; group < oc_hpack; ++group) {
                    multiplyUnit(i, group, eTileReal, tileBuffer, blockParams.data(), tailParams.data());
                }
            }

            for (int z = 0; z < dc_4; ++z) {
                for (int e = 0; e < eTileReal; ++e) {
                    destTransformUnit(z, e, xIndex, eTileReal, tileBuffer, scratch0, scratch1);
                }
            }
        }
    };

    if (mConvPerfconfig.isParallelInner) {
        for (int tIndex = 0; tIndex < tileCount; ++tIndex) {
            const int xIndex    = tIndex * eTile;
            const int eTileReal = tileWidth(xIndex);

            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                innerSourcePhase((int)tId, xIndex, eTileReal);
            }
            MNN_CONCURRENCY_END();

            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                innerMultiplyPhase((int)tId, eTileReal);
            }
            MNN_CONCURRENCY_END();

            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                innerDestPhase((int)tId, xIndex, eTileReal);
            }
            MNN_CONCURRENCY_END();
        }
    } else {
        MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
            outerTilePhase((int)tId);
        }
        MNN_CONCURRENCY_END();
    }
    return NO_ERROR;
}