in source/backend/cpu/compute/ConvolutionPackFreeWinograd.cpp [77:622]
/**
 * Executes the pack-free Winograd convolution on inputs[0] and writes the result to outputs[0].
 *
 * The output is split into wUnit * hUnit * batch tiles of dstUnit x dstUnit pixels, which are
 * processed in groups of at most mConvPerfconfig.eTile tiles. Every group runs through three stages:
 *   1. Source transform: each srcUnit x srcUnit input patch (zero filled where it leaves the
 *      input) is transformed and stored in mTempBuffer.
 *   2. Multiply: for each of the srcUnit * srcUnit transform positions the transformed source is
 *      multiplied with the weights by the MNNPackedMatMulOC16/32/48 kernels. The products are
 *      stored in mTempBuffer directly behind the transformed source.
 *   3. Dest transform: the products are transformed back; bias and mPostParameters are handed to
 *      the second transform pass and the tile is written to the output.
 *
 * If mConvPerfconfig.isParallelInner is set, the threads share the work inside each tile group
 * (one parallel region per stage). Otherwise each thread processes whole tile groups on its own
 * slice of the scratch buffers.
 *
 * @param inputs  inputs[0] is the source tensor.
 * @param outputs outputs[0] is the destination tensor.
 * @return NO_ERROR.
 */
ErrorCode ConvolutionPackFreeWinograd::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    int pack = core->pack, bytes = core->bytes;
    auto input = inputs[0];
    auto output = outputs[0];
    auto dstUnit = mA->length(1);
    auto srcUnit = mA->length(0);
    // Only lPack is consumed below; ePackMax / hPack are required as out-arguments of the query.
    int ePackMax, lPack, hPack;
    core->MNNGetMatMulPackMode(&ePackMax, &lPack, &hPack);
    int ePack = mConvPerfconfig.ePack;
    auto srcUnit2 = srcUnit * srcUnit;
    int ow = output->width();
    int oh = output->height();
    int iw = input->width();
    int ih = input->height();
    int oc = output->channel();
    int ic = input->channel();
    int ic_roundup = ROUND_UP(ic, lPack);
    int ic_4 = UP_DIV(input->channel(), pack);
    int dc_4 = UP_DIV(output->channel(), pack);
    int batch = input->batch();
    int padY = mPadY;
    int padX = mPadX;
    auto wUnit = UP_DIV(ow, dstUnit);
    auto hUnit = UP_DIV(oh, dstUnit);
    auto totalCount = wUnit * hUnit * batch;
    int threadNumber = std::max(static_cast<CPUBackend*>(backend())->threadNumber(), 1);
    int tileCount = UP_DIV(totalCount, mConvPerfconfig.eTile);

    // Parameter template handed to the pack-free matmul kernels. Slots [3] and [6] are filled in
    // per tile group below. NOTE(review): the meaning of every slot is defined by the kernels,
    // which are not visible here -- confirm against their implementation before changing any.
    std::vector<size_t> Tile2MatMulParameters = {
        static_cast<size_t>(ePack * ic_4 * pack * bytes),
        static_cast<size_t>(ic),
        0,
        0,
        static_cast<size_t>(ic_roundup * mConvPerfconfig.hPack * bytes),
        static_cast<size_t>(mConvPerfconfig.hPack * bytes),
        0};

    auto srcOrigin = input->host<uint8_t>();
    auto dstOrigin = output->host<uint8_t>();
    auto midBuffer0Bytes = srcUnit2 * pack * bytes;

    auto weight = mResource->mWeight->host<uint8_t>();
    auto bias = mResource->mBias->host<uint8_t>();
    // Scratch memory used by the parallel-inner path (shared temp buffer, per-thread mid buffers).
    auto _srcOrigin = mTempBuffer->host<uint8_t>();
    auto midBuffer0 = mTransformMidBuffer->host<uint8_t>();
    auto midBuffer1 = midBuffer0 + midBuffer0Bytes;

    /* Stage 1 (parallel inner): source transform of tile group tIndex, split over threads. */
    auto parallelInnerSourceFunction = [&](int tId, int tIndex) {
        const int eTile = mConvPerfconfig.eTile;
        const int xIndex = tIndex * eTile;
        const int xRemain = totalCount - xIndex;
        const int eTileReal = xRemain > eTile ? eTile : xRemain;
        /*Source Transform Begin*/
        const int bTransStride = wUnit * hUnit;
        const int ib_stride = iw * ih;
        const int pack_stride = pack * bytes;
        const int ICUnitStep = ic_4 * eTileReal * pack;
        const int sourceZStep = ib_stride * batch * pack_stride;
        const int IcBufferOffset = mTransformMidBuffer->stride(0);
        // Every thread works on its own pair of mid buffers.
        auto icMidBuffer0 = midBuffer0 + tId * IcBufferOffset;
        auto icMidBuffer1 = midBuffer1 + tId * IcBufferOffset;
        // One work item = one (channel block z, tile) pair.
        for (int tile_k_z = tId; tile_k_z < ic_4 * eTileReal; tile_k_z += threadNumber) {
            int z = tile_k_z / eTileReal;
            int eTileNumber = tile_k_z % eTileReal;
            int tile_k = eTileNumber + xIndex;
            int bIndex = tile_k / bTransStride;
            int hwIndex = tile_k % bTransStride;
            int hIndex = (hwIndex / wUnit);
            int wIndex = (hwIndex % wUnit);
            int iEpack = eTileNumber % ePack;
            int iETile = eTileNumber - iEpack;
            // Width of the ePack segment this tile belongs to (shorter for the trailing segment).
            int ePackSegment = ALIMIN(ePack, eTileReal - iETile);
            int ihIndex = hIndex * dstUnit - padY;
            int iwIndex = wIndex * dstUnit - padX;
            // [sx, ex) x [sy, ey) is the part of the patch that lies inside the input.
            int ey = ALIMIN(ihIndex + srcUnit, ih) - ihIndex;
            int sy = ALIMAX(0, ihIndex) - ihIndex;
            int ex = ALIMIN(iwIndex + srcUnit, iw) - iwIndex;
            int sx = ALIMAX(0, iwIndex) - iwIndex;
            int count = pack_stride * (ex - sx);
            auto srcZ = srcOrigin + (iwIndex + ihIndex * iw + bIndex * ib_stride) * pack_stride + z * sourceZStep;
            auto dstZ = _srcOrigin + (iETile * ic_4 + z * ePackSegment + iEpack) * pack_stride;
            if (ex - sx == srcUnit && ey - sy == srcUnit) {
                // Patch fully inside the input: transform straight from the source tensor.
                mSourceUnrollTransform((const float*)srcZ, (float*)icMidBuffer1, iw * pack, pack, pack, pack * srcUnit);
            } else {
                // Extract: copy the valid region into a zeroed patch, then transform that.
                ::memset(icMidBuffer0, 0, mTransformMidBuffer->stride(1));
                if (count > 0) {
                    for (int yy = sy; yy < ey; ++yy) {
                        auto dst_yy = icMidBuffer0 + (yy * srcUnit + sx) * pack_stride;
                        auto src_yy = srcZ + (iw * yy + sx) * pack_stride;
                        ::memcpy(dst_yy, src_yy, count);
                    }
                }
                mSourceUnrollTransform((const float*)icMidBuffer0, (float*)icMidBuffer1, srcUnit * pack, pack, pack, pack * srcUnit);
            }
            // Second pass, writing into the temp buffer.
            mSourceUnrollTransform((const float*)icMidBuffer1, (float*)dstZ, srcUnit * pack, ICUnitStep, pack, ICUnitStep * srcUnit);
        }
    };

    /* Stage 2 (parallel inner): multiply of tile group tIndex, split over threads. */
    auto parallelInnerPackFreeMultiplyFunction = [&](int tId, int tIndex) {
        const int eTile = mConvPerfconfig.eTile;
        const int hPackDynamic = mConvPerfconfig.hPack;
        const int xIndex = tIndex * eTile;
        const int xRemain = totalCount - xIndex;
        const int eTileReal = xRemain > eTile ? eTile : xRemain;
        // tBlock tiles are handled by the full ePack kernel, the remaining tLast by a tail kernel.
        const int tLast = eTileReal % ePack;
        const int tBlock = eTileReal - tLast;
        const int oc_hpack = UP_DIV(oc, hPackDynamic);
        const int oc_pack_coeff = hPackDynamic / pack;
        const int weightStride = mResource->mWeight->stride(0);
        const int pack_stride = pack * bytes;
        auto threadParameters = Tile2MatMulParameters;
        auto threadParametersRemain = threadParameters;
        threadParameters[6] = tBlock;
        threadParametersRemain[6] = tLast;
        threadParameters[3] = eTileReal * pack_stride;
        threadParametersRemain[3] = threadParameters[3];
        // copy pointer out
        // With tLast == 0 the tail kernels are never called; clamp the index so the lookup does
        // not read element -1 of the kernel tables.
        const int tailKernelIndex = tLast > 0 ? tLast - 1 : 0;
        auto MaxATileMatMulOC16Function = core->MNNPackedMatMulOC16Functions[ePack - 1];
        auto TailATileMatMulOC16Function = core->MNNPackedMatMulOC16Functions[tailKernelIndex];
        auto MaxATileMatMulOC32Function = core->MNNPackedMatMulOC32Functions[ePack - 1];
        auto TailATileMatMulOC32Function = core->MNNPackedMatMulOC32Functions[tailKernelIndex];
        auto MaxATileMatMulOC48Function = core->MNNPackedMatMulOC48Functions[ePack - 1];
        auto TailATileMatMulOC48Function = core->MNNPackedMatMulOC48Functions[tailKernelIndex];
        // Products are stored right behind the transformed source inside the temp buffer.
        auto* _dstOrigin = _srcOrigin + eTileReal * srcUnit2 * ic_4 * pack * bytes;
        // srcUnit2, oc
        for (int i_oc_src = tId; i_oc_src < srcUnit2 * oc_hpack; i_oc_src += threadNumber) {
            int t_oc_mul = i_oc_src % oc_hpack;
            int i = i_oc_src / oc_hpack;
            int t_oc = t_oc_mul * oc_pack_coeff;
            // Number of output channel packs covered by this call. Values outside 1..3 match no
            // case of the switches below and are skipped.
            int ocValidPack = ALIMIN(t_oc + oc_pack_coeff, dc_4) - t_oc;
            // calculate address
            auto srcTemp = (_srcOrigin + i * ic_4 * eTileReal * pack * bytes);
            auto _weightFloatPtr = (const float*)(weight + i * weightStride + (t_oc * ic_roundup * pack) * bytes);
            auto _dstFloatPtr = (_dstOrigin + (i * dc_4 + t_oc) * eTileReal * pack * bytes);
#ifdef PROFILE_DETAIL
            macs[tId] += eTileReal * (2 * ic) * (ocValidPack) * pack;
#endif
            if (tBlock) {
                switch (ocValidPack) {
                    case 1:
                        MaxATileMatMulOC16Function((float*)_dstFloatPtr, (const float*)srcTemp, _weightFloatPtr, threadParameters.data(), nullptr, nullptr);
                        break;
                    case 2:
                        MaxATileMatMulOC32Function((float*)_dstFloatPtr, (const float*)srcTemp, _weightFloatPtr, threadParameters.data(), nullptr, nullptr);
                        break;
                    case 3:
                        MaxATileMatMulOC48Function((float*)_dstFloatPtr, (const float*)srcTemp, _weightFloatPtr, threadParameters.data(), nullptr, nullptr);
                        break;
                }
                srcTemp += tBlock * ic_4 * pack * bytes;
                _dstFloatPtr += tBlock * pack * bytes;
            }
            if (tLast) {
                switch (ocValidPack) {
                    case 1:
                        TailATileMatMulOC16Function((float*)_dstFloatPtr, (const float*)srcTemp, _weightFloatPtr, threadParametersRemain.data(), nullptr, nullptr);
                        break;
                    case 2:
                        TailATileMatMulOC32Function((float*)_dstFloatPtr, (const float*)srcTemp, _weightFloatPtr, threadParametersRemain.data(), nullptr, nullptr);
                        break;
                    case 3:
                        TailATileMatMulOC48Function((float*)_dstFloatPtr, (const float*)srcTemp, _weightFloatPtr, threadParametersRemain.data(), nullptr, nullptr);
                        break;
                }
            }
        }
    };

    /* Stage 3 (parallel inner): dest transform and post treat of tile group tIndex. */
    /* Dest Transform And Post Treat Begin */
    auto parallelInnerDestFunction = [&](int tId, int tIndex) {
        auto DestUnrollTransform = mDestUnrollTransform.get();
        const int eTile = mConvPerfconfig.eTile;
        const int xIndex = tIndex * eTile;
        const int xRemain = totalCount - xIndex;
        const int eTileReal = xRemain > eTile ? eTile : xRemain;
        const int pack_stride = pack * bytes;
        const int transb_stride = wUnit * hUnit;
        const int ob_stride = ow * oh;
        const int srcTransZStep = eTileReal * pack_stride;
        const int OCUnitStep = eTileReal * pack * dc_4;
        const int dstZStep = ob_stride * batch * pack_stride;
        const auto ocBufferOffset = mTransformMidBuffer->stride(0);
        const auto srcOriginSegment = _srcOrigin + eTileReal * srcUnit2 * ic_4 * pack_stride;
        const float* postParameters = mPostParameters.data();
        // Every thread works on its own pair of mid buffers.
        auto ocMidBuffer0 = midBuffer0 + tId * ocBufferOffset;
        auto ocMidBuffer1 = midBuffer1 + tId * ocBufferOffset;
        // One work item = one (output channel block z, tile) pair.
        for (int tile_k_z = tId; tile_k_z < dc_4 * eTileReal; tile_k_z += threadNumber) {
            int z = tile_k_z / eTileReal;
            int tile_k = (tile_k_z % eTileReal) + xIndex;
            int bIndex = tile_k / transb_stride;
            int hwIndex = tile_k % transb_stride;
            int hIndex = (hwIndex / wUnit);
            int wIndex = (hwIndex % wUnit);
            int ohIndex = hIndex * dstUnit;
            int owIndex = wIndex * dstUnit;
            const float* biasFloatPtr = (const float*)(bias + z * pack_stride);
            // ex / ey: number of valid output columns / rows of this tile.
            int ey = ALIMIN(ohIndex + dstUnit, oh) - ohIndex;
            int ex = ALIMIN(owIndex + dstUnit, ow) - owIndex;
            auto dstStart = dstOrigin + (owIndex + ohIndex * ow + bIndex * ob_stride) * pack_stride;
            auto srcStart = srcOriginSegment + (tile_k - xIndex) * pack_stride;
            int count = ex * pack_stride;
            auto dstZAddr = dstStart + z * dstZStep;
            auto srcZ = srcStart + z * srcTransZStep;
            // First pass is the same for complete and clipped tiles.
            DestUnrollTransform[srcUnit]((const float*)srcZ, (float*)ocMidBuffer0, nullptr, nullptr, OCUnitStep, dstUnit * pack, srcUnit * OCUnitStep, pack);
            if (ex == dstUnit) {
                // Full width: the second pass writes directly into the output tensor.
                DestUnrollTransform[ey]((const float*)ocMidBuffer0, (float*)dstZAddr, biasFloatPtr, postParameters, pack, pack * ow, pack * dstUnit, pack);
            } else {
                // Clipped width: transform into scratch, then copy only the valid columns.
                DestUnrollTransform[ey]((const float*)ocMidBuffer0, (float*)ocMidBuffer1, biasFloatPtr, postParameters, pack, pack * dstUnit, pack * dstUnit, pack);
                for (int yy = 0; yy < ey; ++yy) {
                    auto dstYAddr = dstZAddr + yy * ow * pack_stride;
                    auto srcYAddr = ocMidBuffer1 + yy * dstUnit * pack_stride;
                    ::memcpy(dstYAddr, srcYAddr, count);
                }
            }
        }
        /*Dest Transform And Post Treat End*/
    };

    /* Parallel outer: thread tId runs all three stages for tile groups tId, tId + threadNumber, ... */
    auto parallelOuterPackFreeFunction = [&](int tId) {
        int eTile = mConvPerfconfig.eTile;
        int hPackDynamic = mConvPerfconfig.hPack;
        // Thread-private slices of the scratch buffers (these intentionally shadow the shared
        // pointers declared in the enclosing function).
        auto _srcOrigin = mTempBuffer->host<uint8_t>() + tId * mTempBuffer->stride(0);
        auto midBuffer0 = mTransformMidBuffer->host<uint8_t>() + tId * mTransformMidBuffer->stride(0);
        auto midBuffer1 = midBuffer0 + midBuffer0Bytes;
        for (int tIndex = (int)tId; tIndex < tileCount; tIndex += threadNumber) {
            int xIndex = tIndex * eTile;
            int xRemain = totalCount - xIndex;
            int eTileReal = xRemain > eTile ? eTile : xRemain;
            /*Source Transform Begin*/
            const int bTransStride = wUnit * hUnit;
            const int ib_stride = iw * ih;
            const int pack_stride = pack * bytes;
            const int ICUnitStep = ic_4 * eTileReal * pack;
            const int sourceZStep = iw * ih * batch * pack_stride;
            for (int z = 0; z < ic_4; z++) {
                for (int tile_k = xIndex; tile_k < xIndex + eTileReal; tile_k++) {
                    int bIndex = tile_k / bTransStride;
                    int hwIndex = tile_k % bTransStride;
                    int hIndex = (hwIndex / wUnit);
                    int wIndex = (hwIndex % wUnit);
                    int eTileNumber = tile_k - xIndex;
                    int iEpack = eTileNumber % ePack;
                    int iETile = eTileNumber - iEpack;
                    // Width of the ePack segment this tile belongs to (shorter for the trailing segment).
                    int ePackSegment = ALIMIN(ePack, eTileReal - iETile);
                    int ihIndex = hIndex * dstUnit - padY;
                    int iwIndex = wIndex * dstUnit - padX;
                    // [sx, ex) x [sy, ey) is the part of the patch that lies inside the input.
                    int ey = ALIMIN(ihIndex + srcUnit, ih) - ihIndex;
                    int sy = ALIMAX(0, ihIndex) - ihIndex;
                    int ex = ALIMIN(iwIndex + srcUnit, iw) - iwIndex;
                    int sx = ALIMAX(0, iwIndex) - iwIndex;
                    int count = pack_stride * (ex - sx);
                    auto srcZ = srcOrigin + (iwIndex + ihIndex * iw + bIndex * ib_stride) * pack_stride + z * sourceZStep;
                    auto dstZ = _srcOrigin + (iETile * ic_4 + z * ePackSegment + iEpack) * pack_stride;
                    if (ex - sx == srcUnit && ey - sy == srcUnit) {
                        // Transform
                        mSourceUnrollTransform((const float*)srcZ, (float*)midBuffer1, iw * pack, pack, pack, pack * srcUnit);
                    } else {
                        // Extract: copy the valid region into a zeroed patch, then transform that.
                        ::memset(midBuffer0, 0, mTransformMidBuffer->stride(1));
                        if (count > 0) {
                            for (int yy = sy; yy < ey; ++yy) {
                                auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack_stride;
                                auto src_yy = srcZ + (iw * yy + sx) * pack_stride;
                                ::memcpy(dst_yy, src_yy, count);
                            }
                        }
                        mSourceUnrollTransform((const float*)midBuffer0, (float*)midBuffer1, srcUnit * pack, pack, pack, pack * srcUnit);
                    }
                    // Second pass, writing into the temp buffer.
                    mSourceUnrollTransform((const float*)midBuffer1, (float*)dstZ, srcUnit * pack, ICUnitStep, pack, ICUnitStep * srcUnit);
                }
            }
            /*Source Transform End*/
            //Multi
            // tBlock tiles are handled by the full ePack kernel, the remaining tLast by a tail kernel.
            int tLast = eTileReal % ePack;
            int tBlock = eTileReal - tLast;
            const int oc_hpack = UP_DIV(oc, hPackDynamic);
            const int oc_pack_coeff = hPackDynamic / pack;
            const int weightStride = mResource->mWeight->stride(0);
            auto threadParameters = Tile2MatMulParameters;
            auto threadParametersRemain = threadParameters;
            threadParameters[6] = tBlock;
            threadParametersRemain[6] = tLast;
            threadParameters[3] = eTileReal * pack_stride;
            threadParametersRemain[3] = threadParameters[3];
            // copy pointer out
            // With tLast == 0 the tail kernels are never called; clamp the index so the lookup
            // does not read element -1 of the kernel tables.
            const int tailKernelIndex = tLast > 0 ? tLast - 1 : 0;
            auto MaxATileMatMulOC16Function = core->MNNPackedMatMulOC16Functions[ePack - 1];
            auto TailATileMatMulOC16Function = core->MNNPackedMatMulOC16Functions[tailKernelIndex];
            auto MaxATileMatMulOC32Function = core->MNNPackedMatMulOC32Functions[ePack - 1];
            auto TailATileMatMulOC32Function = core->MNNPackedMatMulOC32Functions[tailKernelIndex];
            auto MaxATileMatMulOC48Function = core->MNNPackedMatMulOC48Functions[ePack - 1];
            auto TailATileMatMulOC48Function = core->MNNPackedMatMulOC48Functions[tailKernelIndex];
            // Products are stored right behind the transformed source inside the temp buffer.
            auto* _dstOrigin = _srcOrigin + eTileReal * srcUnit2 * ic_4 * pack * bytes;
            for (int i = 0; i < srcUnit2; ++i) {
                for (int t_oc_mul = 0; t_oc_mul < oc_hpack; ++t_oc_mul) {
                    int t_oc = t_oc_mul * oc_pack_coeff;
                    // Number of output channel packs covered by this call. Values outside 1..3
                    // match no case of the switches below and are skipped.
                    int ocValidPack = ALIMIN(t_oc + oc_pack_coeff, dc_4) - t_oc;
                    auto srcPtr = (_srcOrigin + i * ic_4 * eTileReal * pack * bytes);
                    auto _weightFloatPtr = (const float*)(weight + i * weightStride + (t_oc * ic_roundup * pack) * bytes);
                    auto _dstFloatPtr = (_dstOrigin + (i * dc_4 + t_oc) * eTileReal * pack * bytes);
#ifdef PROFILE_DETAIL
                    macs += eTileReal * (2 * ic) * (ocValidPack) * pack;
#endif
                    if (tBlock) {
                        switch (ocValidPack) {
                            case 1:
                                MaxATileMatMulOC16Function((float*)_dstFloatPtr, (const float*)srcPtr, _weightFloatPtr, threadParameters.data(), nullptr, nullptr);
                                break;
                            case 2:
                                MaxATileMatMulOC32Function((float*)_dstFloatPtr, (const float*)srcPtr, _weightFloatPtr, threadParameters.data(), nullptr, nullptr);
                                break;
                            case 3:
                                MaxATileMatMulOC48Function((float*)_dstFloatPtr, (const float*)srcPtr, _weightFloatPtr, threadParameters.data(), nullptr, nullptr);
                                break;
                        }
                        srcPtr += tBlock * ic_4 * pack * bytes;
                        _dstFloatPtr += tBlock * pack * bytes;
                    }
                    if (tLast) {
                        switch (ocValidPack) {
                            case 1:
                                TailATileMatMulOC16Function((float*)_dstFloatPtr, (const float*)srcPtr, _weightFloatPtr, threadParametersRemain.data(), nullptr, nullptr);
                                break;
                            case 2:
                                TailATileMatMulOC32Function((float*)_dstFloatPtr, (const float*)srcPtr, _weightFloatPtr, threadParametersRemain.data(), nullptr, nullptr);
                                break;
                            case 3:
                                TailATileMatMulOC48Function((float*)_dstFloatPtr, (const float*)srcPtr, _weightFloatPtr, threadParametersRemain.data(), nullptr, nullptr);
                                break;
                        }
                    }
                }
            }
            /* Dest Transform And Post Treat Begin */
            const int transb_stride = wUnit * hUnit;
            const int ob_stride = ow * oh;
            const int srcTransZStep = eTileReal * pack_stride;
            const int OCUnitStep = eTileReal * pack * dc_4;
            const int dstZStep = ob_stride * batch * pack_stride;
            const auto srcOriginSegment = _srcOrigin + eTileReal * srcUnit2 * ic_4 * pack_stride;
            const float* postParameters = mPostParameters.data();
            auto DestUnrollTransform = mDestUnrollTransform.get();
            for (int z = 0; z < dc_4; ++z) {
                const float* biasFloatPtr = (const float*)(bias + z * pack_stride);
                for (int tile_k = xIndex; tile_k < xIndex + eTileReal; tile_k++) {
                    int bIndex = tile_k / transb_stride;
                    int hwIndex = tile_k % transb_stride;
                    int hIndex = (hwIndex / wUnit);
                    int wIndex = (hwIndex % wUnit);
                    int ohIndex = hIndex * dstUnit;
                    int owIndex = wIndex * dstUnit;
                    // ex / ey: number of valid output columns / rows of this tile.
                    int ey = ALIMIN(ohIndex + dstUnit, oh) - ohIndex;
                    int ex = ALIMIN(owIndex + dstUnit, ow) - owIndex;
                    auto dstZPtr = dstOrigin + (owIndex + ohIndex * ow + bIndex * ob_stride) * pack_stride + z * dstZStep;
                    auto srcZPtr = srcOriginSegment + (tile_k - xIndex) * pack_stride + z * srcTransZStep;
                    int count = ex * pack_stride;
                    // First pass is the same for complete and clipped tiles.
                    DestUnrollTransform[srcUnit]((const float*)srcZPtr, (float*)midBuffer0, nullptr, nullptr, OCUnitStep, dstUnit * pack, srcUnit * OCUnitStep, pack);
                    if (ex == dstUnit) {
                        // Full width: the second pass writes directly into the output tensor.
                        DestUnrollTransform[ey]((const float*)midBuffer0, (float*)dstZPtr, biasFloatPtr, postParameters, pack, pack * ow, pack * dstUnit, pack);
                    } else {
                        // Clipped width: transform into scratch, then copy only the valid columns.
                        DestUnrollTransform[ey]((const float*)midBuffer0, (float*)midBuffer1, biasFloatPtr, postParameters, pack, pack * dstUnit, pack * dstUnit, pack);
                        for (int yy = 0; yy < ey; ++yy) {
                            auto dstYAddr = dstZPtr + yy * ow * pack_stride;
                            auto srcYAddr = midBuffer1 + yy * dstUnit * pack_stride;
                            ::memcpy(dstYAddr, srcYAddr, count);
                        }
                    }
                }
            }
            /*Dest Transform And Post Treat End*/
        }
        // NOTE(review): macs and the duration* counters used under PROFILE_DETAIL are not declared
        // inside this function -- confirm they exist elsewhere before enabling that macro.
#ifdef PROFILE_DETAIL
        double gflops = (double)macs / 1000.0 / durationMul;
        MNN_PRINT(
            "conv winograd. mParallelInner:%d, tId:%d, lastTile:%d, srcUnit: %d, inside measure: sourceTrans1:%lu us, "
            "sourceTrans2:%lu us, packATime:%lu us, durationMul:%lu us, destTrans:%lu us, total:%lu us. %.3f GFLOPS, "
            "macs:%lu\n",
            mConvPerfconfig.isParallelInner, tId, tileCount % ePack, srcUnit, durationSourceTrans1,
            durationSourceTrans2, packATime, durationMul, durationDestTrans1,
            durationSourceTrans1 + durationSourceTrans2 + packATime + durationMul + durationDestTrans1, gflops, macs);
#endif
    };

    if (mConvPerfconfig.isParallelInner) {
        // The stages of one tile group depend on each other, so each stage gets its own
        // parallel region and the next one starts only after the previous region has ended.
        for (int tIndex = 0; tIndex < tileCount; tIndex += 1) {
            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                parallelInnerSourceFunction((int)tId, tIndex);
            }
            MNN_CONCURRENCY_END();
            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                parallelInnerPackFreeMultiplyFunction((int)tId, tIndex);
            }
            MNN_CONCURRENCY_END();
            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                parallelInnerDestFunction((int)tId, tIndex);
            }
            MNN_CONCURRENCY_END();
        }
    } else {
        MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
            parallelOuterPackFreeFunction(tId);
        }
        MNN_CONCURRENCY_END();
    }
    return NO_ERROR;
}