in source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp [423:718]
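// onResize: precomputes the im2col/GEMM tiling plan for a dense convolution,
// sizes the per-thread packing buffers, then captures the whole execution as a
// lambda in mFunction. Two schedules are prepared, selected by
// mConvPerfconfig.isParallelInner below.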
ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
CPUConvolution::onResize(inputs, outputs);
auto input = inputs[0];
auto weight = inputs[1];
Tensor *bias = nullptr;
if (inputs.size() > 2) {
bias = inputs[2];
}
auto core = static_cast<CPUBackend *>(backend())->functions();
int bytes = core->bytes;
float weightBytes = bytes;
int unit = core->pack;
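// The packed-A buffer may use a different byte width than tensor storage when
// the backend's matmul kernels work in another precision (core->matmulBytes).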
int matmulBytes = bytes;
if (core->matmulBytes != 0) {
matmulBytes = core->matmulBytes;
}
auto packA = core->MNNPackC4ForMatMul_A;
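// eP/lP/hP are the backend's matmul tiling factors: eP output points per
// packed A tile, lP packing along the reduction axis, and hP packing along
// the weight's output-channel axis.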
int eP, lP, hP;
getPackParameter(&eP, &lP, &hP, core);
auto matmulUnit = core->MNNPackedMatMul;
auto matmulRemain = core->MNNPackedMatMulRemain;
const uint8_t* dequantAlpha = nullptr;
const uint8_t* dequantBias = nullptr;
auto ic = input->channel();
auto icC4 = UP_DIV(ic, unit);
auto L = ic * mCommon->kernelY() * mCommon->kernelX();
auto tileC = std::max(unit, hP);
int blockSize = L;
int blockNum = 1;
float halfStride = 1;
size_t weightStride = 0;
#ifdef MNN_LOW_MEMORY
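// Low-memory path: weights stay quantized (8-bit here, per the assert below)
// and are dequantized inside the matmul kernels. mScaleBias holds per-block
// scale/bias pairs; blockNum > 1 splits the reduction axis L into blockNum
// quantization blocks of blockSize each. halfStride would shrink the
// per-block weight stride for sub-byte weights; with 8-bit weights it stays 1.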
if (mResource && mResource->mDequantize.bits <= 8) {
MNN_ASSERT(mResource->mDequantize.bits == 8);
DenseConvolutionTiledExecutor::selectLowMemoryMatmulFunc(&matmulUnit, &matmulRemain, &weightBytes, mResource->mDequantize.bits, core);
int scaleSize = mResource->mDequantize.mScaleBias->size() / (2 * bytes);
blockNum = scaleSize / (mResource->hU * mResource->hP);
blockSize /= blockNum;
dequantAlpha = mResource->mDequantize.mScaleBias->host<uint8_t>();
dequantBias = dequantAlpha + scaleSize * bytes;
weightStride = (L - blockSize) * hP;
}
#endif
auto output = outputs[0];
auto batch = output->batch();
int threadNumber = static_cast<CPUBackend *>(backend())->threadNumber();
int LRoundup = ROUND_UP(L, lP);
int LRoundupC4 = UP_DIV(LRoundup, unit);
auto outputChannel = output->channel();
auto oC4 = UP_DIV(outputChannel, tileC);
auto ocUp4 = ROUND_UP(outputChannel, hP);
auto kernelSize = mCommon->kernelX() * mCommon->kernelY();
ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParameters, mCommon, input, output, mPadX, mPadY, core, nullptr);
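// Per-thread scratch for one packed A (im2col) tile: L rounded up to lP,
// times eP output points, in the matmul kernel's byte width.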
mTempBufferTranspose.buffer().type = halide_type_of<uint8_t>();
mTempBufferTranspose.buffer().dimensions = 2;
mTempBufferTranspose.buffer().dim[0].extent = threadNumber;
mTempBufferTranspose.buffer().dim[1].extent = UP_DIV(L, lP) * lP * eP * matmulBytes;
TensorUtils::setLinearLayout(&mTempBufferTranspose);
auto plane = mIm2ColParameters.ow * mIm2ColParameters.oh * batch;
int tileCount = UP_DIV(plane, eP);
mConvPerfconfig = bestTileConvolutionConfig(mCommon, input, output, threadNumber, backend());
bool success = backend()->onAcquireBuffer(&mTempBufferTranspose, Backend::DYNAMIC);
if (!success) {
return OUT_OF_MEMORY;
}
auto bufferAlloc = static_cast<CPUBackend *>(backend())->getBufferAllocator();
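// An eP-point tile spans at most UP_DIV(eP, ow) + 1 output rows; each
// (kernel position, row) blit needs a source pointer plus a 4-int "el"
// record, replicated per thread.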
auto maxLine = UP_DIV(eP, mIm2ColParameters.ow) + 1;
auto tempPtr = bufferAlloc->alloc(kernelSize * maxLine * threadNumber * (4 * sizeof(int32_t) + sizeof(float *)));
if (tempPtr.invalid()) {
return OUT_OF_MEMORY;
}
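// Release the scratch immediately, the usual resize-time pattern: the memory
// stays usable while this op executes but may be reassigned to later ops.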
backend()->onReleaseBuffer(&mTempBufferTranspose, Backend::DYNAMIC);
bufferAlloc->free(tempPtr);
auto postParameters = getPostParameters();
mFunction.first = threadNumber;
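// Two schedules: parallelize the inner work of each tile (im2col packing and
// output-channel blocks), or hand whole eP-point tiles to different threads.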
if (mConvPerfconfig.isParallelInner) {
std::vector<int> ocC4ParallelSize(threadNumber + 1);
ocC4ParallelSize[0] = 0;
static_cast<CPUBackend *>(backend())->computeDivideSizes(oC4, ocC4ParallelSize.data()+1);
mFunction.second = [=](int placeholder) {
const float* biasPtr = bias ? bias->host<float>() : nullptr;
// All threads cooperate on one tile, so only buffer slot 0 is used here.
auto gemmBuffer = mTempBufferTranspose.host<uint8_t>();
auto srcPtr = (float const **)(tempPtr.ptr());
auto el = (int32_t *)(srcPtr + kernelSize * maxLine);
auto weightPtr = weight->host<uint8_t>();
int32_t info[4];
info[1] = mIm2ColParameters.iw * mIm2ColParameters.ih * batch;
info[2] = eP;
info[3] = mIm2ColParameters.strideX;
size_t parameters[PARAMETERSIZE];
parameters[0] = eP * bytes;
parameters[1] = blockSize;
parameters[2] = outputChannel;
parameters[3] = plane * unit * bytes;
parameters[4] = 0;
parameters[5] = weightStride; // only used for block-wise quantized weights
parameters[6] = 0;
auto dstOrigin = output->host<uint8_t>();
auto srcOrigin = input->host<uint8_t>();
std::vector<int> im2colParallelSize(threadNumber + 1);
im2colParallelSize[0] = 0;
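// Walk all eP-point tiles sequentially; the work inside each tile is split
// across threads below.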
for (int x = 0; x < tileCount; x += 1) {
int start = x * eP;
int remain = plane - start;
int xC = remain > eP ? eP : remain;
auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo(srcPtr, el, start, xC, mIm2ColParameters, srcOrigin, bytes);
int number = res.first;
bool needZero = res.second;
if (needZero || lP != 1) {
::memset(gemmBuffer, 0, mTempBufferTranspose.stride(0));
}
// Each work item below packs a single blit, so packA always sees info[0] = 1.
info[0] = 1;
int hw4Stride = info[1] * unit * bytes;
static_cast<CPUBackend *>(backend())->computeDivideSizes(number * icC4, im2colParallelSize.data() + 1);
im2colParallelSize[0] = 0;
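// Pack the A matrix in parallel: the work grid is (blit index) x (block of
// `unit` input channels), flattened into number * icC4 items.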
MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
int threadEL[4];
int ticSta = im2colParallelSize[tId];
int ticEnd = im2colParallelSize[tId+1];
for(int tic_inumber = ticSta; tic_inumber < ticEnd; tic_inumber++) {
int inumber = tic_inumber / icC4;
int t_ic = tic_inumber % icC4;
memcpy(threadEL, el + 4 * inumber, 4 * sizeof(int));
threadEL[1] = std::min(ic - (t_ic * unit), unit);
const float* source = (const float*)((const uint8_t*)(srcPtr[inumber]) + t_ic * hw4Stride);
auto gemmDest = gemmBuffer + t_ic * unit * eP * bytes;
packA((float *)gemmDest, &source, info, threadEL);
}
}
MNN_CONCURRENCY_END();
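// A full tile takes the fast eP-wide kernel; the final partial tile goes
// through the remain kernel with the explicit count xC.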
if (xC == eP) {
MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
size_t paraParameters[PARAMETERSIZE];
memcpy(paraParameters, parameters, PARAMETERSIZE * sizeof(size_t));
for (int t_oc = ocC4ParallelSize[tId]; t_oc < ocC4ParallelSize[tId+1]; ++t_oc) {
int ocIndex = t_oc * tileC;
auto _dstFloatPtr = reinterpret_cast<float*>(dstOrigin + (ocIndex / unit * plane + start) * unit * bytes);
auto _weightFloatPtr = reinterpret_cast<const float*>(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes));
auto _biasFloatPtr = reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(biasPtr) + ocIndex * bytes);
paraParameters[2] = std::min(outputChannel - ocIndex, tileC);
auto k = reinterpret_cast<const uint8_t*>(dequantAlpha + ocIndex * bytes);
auto b = reinterpret_cast<const uint8_t*>(dequantBias + ocIndex * bytes);
const float* relufp32 = nullptr;
const float* exeBiasPtr = nullptr;
int finishedL = 0;
int wquantStride = 0;
auto _weightPtr = reinterpret_cast<const int8_t*>(_weightFloatPtr);
uint8_t* _APtr = reinterpret_cast<uint8_t*>(gemmBuffer);
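// Accumulate over quantization blocks along L; bias and activation
// (relufp32) are applied only on the last block. The k/b dequant pointers
// are only meaningful in the low-memory quantized path.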
for (int bk = 0; bk < blockNum; ++bk) {
paraParameters[6] = bk;
if (bk == blockNum - 1) {
relufp32 = postParameters.data();
exeBiasPtr = _biasFloatPtr;
}
finishedL = blockSize * bk;
wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);
matmulUnit(_dstFloatPtr, (float*)(_APtr + eP * finishedL * bytes), (float*)(_weightPtr + wquantStride), paraParameters, relufp32, exeBiasPtr, (float*)(k + bk * ocUp4 * bytes), (float*)(b + bk * ocUp4 * bytes));
}
}
}
MNN_CONCURRENCY_END();
} else {
MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
size_t paraParameters[PARAMETERSIZE];
memcpy(paraParameters, parameters, PARAMETERSIZE * sizeof(size_t));
for (int t_oc = ocC4ParallelSize[tId]; t_oc < ocC4ParallelSize[tId+1]; ++t_oc) {
int ocIndex = t_oc * tileC;
auto _dstFloatPtr = reinterpret_cast<float*>(dstOrigin + (ocIndex / unit * plane + start) * unit * bytes);
auto _weightFloatPtr = reinterpret_cast<const float*>(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes));
auto _biasFloatPtr = reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(biasPtr) + ocIndex * bytes);
paraParameters[2] = std::min(outputChannel - ocIndex, tileC);
auto k = reinterpret_cast<const uint8_t*>(dequantAlpha + ocIndex * bytes);
auto b = reinterpret_cast<const uint8_t*>(dequantBias + ocIndex * bytes);
const float* relufp32 = nullptr;
const float* exeBiasPtr = nullptr;
int finishedL = 0;
int wquantStride = 0;
const int8_t* _weightPtr = reinterpret_cast<const int8_t*>(_weightFloatPtr);
uint8_t* _APtr = reinterpret_cast<uint8_t*>(gemmBuffer);
for (int bk = 0; bk < blockNum; ++bk) {
paraParameters[6] = bk;
if (bk == blockNum - 1) {
relufp32 = postParameters.data();
exeBiasPtr = _biasFloatPtr;
}
finishedL = blockSize * bk;
wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);
matmulRemain(_dstFloatPtr, (float*)(_APtr + eP * finishedL * bytes), (float*)(_weightPtr + wquantStride), xC, paraParameters, relufp32, exeBiasPtr, (float*)(k + bk * ocUp4 * bytes), (float*)(b + bk * ocUp4 * bytes));
}
}
}
MNN_CONCURRENCY_END();
}
}
};
} else {
std::vector<int> divides(threadNumber + 1);
divides[0] = 0;
static_cast<CPUBackend *>(backend())->computeDivideSizes(tileCount, divides.data() + 1);
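// Tile-parallel schedule: each thread handles a contiguous range of tiles
// and owns its own slice of the packed A buffer.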
mFunction.second = [=](int tId) {
const float* biasPtr = bias ? bias->host<float>() : nullptr;
auto gemmBuffer = mTempBufferTranspose.host<uint8_t>() + mTempBufferTranspose.stride(0) * tId;
auto srcPtr = (float const **)(tempPtr.ptr() + tId * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *)));
auto el = (int32_t *)(srcPtr + kernelSize * maxLine);
auto weightPtr = weight->host<float>();
int32_t info[4];
info[1] = mIm2ColParameters.iw * mIm2ColParameters.ih * batch;
info[2] = eP;
info[3] = mIm2ColParameters.strideX;
size_t parameters[PARAMETERSIZE];
parameters[0] = eP * bytes;
parameters[1] = blockSize;
parameters[2] = outputChannel;
parameters[3] = plane * unit * bytes;
parameters[4] = 0;
parameters[5] = weightStride; // only used for block-wise quantized weights
parameters[6] = 0;
auto dstOrigin = output->host<uint8_t>();
auto srcOrigin = input->host<uint8_t>();
int tEnd = divides[tId+1];
int tStart = divides[tId];
for (int x = tStart; x < tEnd; ++x) {
int start = x * eP;
int remain = plane - start;
int xC = remain > eP ? eP : remain;
auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo(srcPtr, el, start, xC, mIm2ColParameters, srcOrigin, bytes);
auto number = res.first;
bool needZero = res.second;
info[0] = number;
if (needZero || lP != 1) {
::memset(gemmBuffer, 0, mTempBufferTranspose.stride(0));
}
if (number > 0) {
packA((float *)gemmBuffer, srcPtr, info, el);
}
int finishedL = 0;
int wquantStride = 0;
int8_t* _weightPtr = reinterpret_cast<int8_t*>(weightPtr);
auto _dstFloatPtr = reinterpret_cast<float*>(dstOrigin + start * unit * bytes);
const float* relufp32 = nullptr;
const float* exeBiasPtr = nullptr;
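// Same block-wise accumulation as the parallel-inner path, over the full
// output-channel range at once; the dequant scale/bias advance by ocUp4
// rows per block.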
if (xC == eP) {
// matmulUnit(_dstFloatPtr, (float*)gemmBuffer, (float*)weightPtr, parameters, postParameters.data(), biasPtr, k, b);
for (int bk = 0; bk < blockNum; ++bk) {
parameters[6] = bk;
if (bk == blockNum - 1) {
relufp32 = postParameters.data();
exeBiasPtr = biasPtr;
}
finishedL = blockSize * bk;
wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);
matmulUnit(_dstFloatPtr, (float*)(gemmBuffer + bytes * eP * finishedL), (float*)(_weightPtr + wquantStride), parameters, relufp32, exeBiasPtr, (float*)(dequantAlpha + bk * ocUp4 * bytes), (float*)(dequantBias + bk * ocUp4 * bytes));
}
} else {
for (int bk = 0; bk < blockNum; ++bk) {
parameters[6] = bk;
if (bk == blockNum - 1) {
relufp32 = postParameters.data();
exeBiasPtr = biasPtr;
}
finishedL = blockSize * bk;
wquantStride = static_cast<int32_t>(blockSize * bk * hP * halfStride);
matmulRemain(_dstFloatPtr, (float*)(gemmBuffer + eP * bytes * finishedL), (float*)(_weightPtr + wquantStride), xC, parameters, relufp32, exeBiasPtr, (float*)(dequantAlpha + bk * ocUp4 * bytes), (float*)(dequantBias + bk * ocUp4 * bytes ));
}
// matmulRemain(_dstFloatPtr, (float*)gemmBuffer, (float*)weightPtr, xC, parameters, postParameters.data(), biasPtr, k, b);
}
}
};
}
return NO_ERROR;
}