in source/backend/cpu/CPURaster.cpp [818:1130]
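// Execute one LoopParam: run any init commands, take a fast path when the loop is a
// single parameter-less UnaryOp (gather / copy), otherwise evaluate every RegionCommand
// per loop iteration, optionally spread across threads.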
virtual ErrorCode onExecute(const std::vector<Tensor *> &originInputs, const std::vector<Tensor *> &originOutputs) override {
auto cpubackend = static_cast<CPUBackend*>(backend());
auto precision = cpubackend->precisionMode();
auto threadNumber = cpubackend->threadNumber();
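// Init commands run once before the loop body: a null op means zero-fill the output
// tensor, otherwise blit the described region from the input into the output.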
if (mLoop->initCommand() != nullptr) {
for (int i=0; i<mLoop->initCommand()->size(); ++i) {
auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(i);
if (cmd->op() == nullptr) {
auto output = mStack[cmd->indexes()->data()[0]];
::memset(output->host<void>(), 0, cpubackend->getTensorSize(output) * cpubackend->functions()->bytes);
} else {
Tensor::InsideDescribe::Region reg;
auto srcView = cmd->view()->GetAs<View>(1);
auto dstView = cmd->view()->GetAs<View>(0);
::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
::memcpy(reg.dst.stride, dstView->stride()->data(), 3 * sizeof(int32_t));
auto input = mStack[cmd->indexes()->data()[1]];
auto output = mStack[cmd->indexes()->data()[0]];
auto bytes = input->getType().bytes();
if (halide_type_float == input->getType().code) {
bytes = cpubackend->functions()->bytes;
}
_blit(reg, bytes, input->host<uint8_t>(), output->host<uint8_t>(), false);
}
}
}
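// Fast path: a single UnaryOp command with no parameters and no fuse is a plain
// gather / copy. Iterate it directly; out-of-range source offsets zero-fill the
// destination region and out-of-range destination offsets are skipped.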
if (1 == mLoop->commands()->size()) {
auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
auto op = cmd->op();
if (OpType_UnaryOp == op->type() && nullptr == op->main() && cmd->fuse() < 0) {
// For Gather / Single Unary
auto index0 = cmd->iterIndexes()->data()[0];
auto index1 = cmd->iterIndexes()->data()[1];
int32_t iter = 0;
int32_t* iter0 = &iter;
int32_t* iter1 = &iter;
int32_t iter0Stride = 0;
int32_t iter1Stride = 0;
if (index0 >= 0) {
iter0 = originInputs[index0]->host<int32_t>();
iter0Stride = 1;
}
if (index1 >= 0) {
iter1 = originInputs[index1]->host<int32_t>();
iter1Stride = 1;
}
Tensor::InsideDescribe::Region reg;
auto srcView = cmd->view()->GetAs<View>(1);
auto dstView = cmd->view()->GetAs<View>(0);
::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
::memcpy(reg.dst.stride, dstView->stride()->data(), 3 * sizeof(int32_t));
auto input = mStack[cmd->indexes()->data()[1]];
auto inputSize = input->usize() / input->buffer().type.bytes();
auto output = mStack[cmd->indexes()->data()[0]];
auto outputSize = output->usize() / output->buffer().type.bytes();
auto bytes = input->getType().bytes();
if (halide_type_float == input->getType().code) {
bytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
}
auto step0 = cmd->steps()->data()[0];
auto step1 = cmd->steps()->data()[1];
auto loopNumber = mLoop->loopNumber();
for (; iter<loopNumber; ++iter) {
auto srcIter = *(iter1 + iter1Stride * iter);
auto dstIter = *(iter0 + iter0Stride * iter);
auto srcOffset = srcIter * step1 + srcView->offset();
auto dstOffset = dstIter * step0 + dstView->offset();
if (dstOffset >= 0 && dstOffset < outputSize) {
if (srcOffset >= 0 && srcOffset < inputSize) {
_blit(reg, bytes, input->host<uint8_t>() + bytes * srcOffset, output->host<uint8_t>() + bytes * dstOffset, false);
} else {
_zero(reg, bytes, output->host<uint8_t>() + bytes * dstOffset);
}
}
}
return NO_ERROR;
}
}
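// General path: 'func' executes every RegionCommand for loop iteration 'iter' on
// thread 'tId', resolving each tensor's base pointer through the optional iter indexes.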
auto bytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
auto func = [&](int iter, int tId) {
int fuseOutputStride[3];
const int32_t* outputStride = nullptr;
auto fuseBuffer = mFuseBuffer + mMaxFuseBufferSize * tId;
for (int index=0; index<mLoop->commands()->size(); ++index) {
auto cmd = mLoop->commands()->GetAs<RegionCommand>(index);
auto blit = _selectUnitProc(bytes, cmd->view()->GetAs<View>(1)->stride()->data()[2], 1);
auto op = cmd->op();
int iterIndexsize = cmd->iterIndexes()->size();
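// When fusing, the command writes into a packed temporary buffer, so derive contiguous
// strides from the command size; otherwise write directly with the output view's strides.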
if (cmd->fuse() >= 0) {
outputStride = fuseOutputStride;
auto cmdSize = cmd->size()->data();
fuseOutputStride[0] = cmdSize[1] * cmdSize[2];
fuseOutputStride[1] = cmdSize[2];
fuseOutputStride[2] = 1;
} else {
// The first index of a Loop command must be the output tensor
outputStride = cmd->view()->GetAs<View>(0)->stride()->data();
}
halide_type_t inputType;
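// Resolve the base pointer of every referenced tensor for this iteration: indirect
// through the iter-index tensor when present, then apply the per-iteration step and
// the view offset.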
for (int v=0; v<iterIndexsize; ++v) {
auto tensorIndex = cmd->indexes()->data()[v];
auto tensor = mStack[tensorIndex];
auto iterIndex = cmd->iterIndexes()->data()[v];
auto offset = iter;
if (1 == v) {
inputType = tensor->getType();
}
if (iterIndex >= 0) {
offset = mStack[iterIndex]->host<int32_t>()[iter];
}
auto view = cmd->view()->GetAs<View>(v);
offset = offset * cmd->steps()->data()[v] + view->offset();
mContainer[tId].stackPtr[tensorIndex] = tensor->host<uint8_t>() + offset * bytes;
MNN_ASSERT(nullptr != tensor->host<uint8_t>());
}
auto dstOrigin = (uint8_t*)mContainer[tId].stackPtr[cmd->indexes()->data()[0]];
auto dst = dstOrigin;
if (cmd->fuse() >= 0) {
dst = fuseBuffer.ptr();
}
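// Dispatch on the command's op type; the do { ... } while(false) lets each branch break out early.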
do {
if (OpType_UnaryOp == op->type()) {
auto src = (uint8_t*)mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
if (nullptr == op->main()) {
// Copy
Tensor::InsideDescribe::Region reg;
auto srcView = cmd->view()->GetAs<View>(1);
auto dstView = cmd->view()->GetAs<View>(0);
::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
::memcpy(reg.dst.stride, outputStride, 3 * sizeof(int32_t));
_blit(reg, bytes, (const uint8_t*)src, (uint8_t*)dst, false);
break;
}
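// General unary: apply the selected element-wise function row by row, staging the
// source through the cache buffer when its innermost stride is not 1.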
auto proc = static_cast<CPUBackend*>(backend())->functions()->MNNSelectUnaryFunctionForFloat(op->main_as_UnaryOp()->opType(), static_cast<CPUBackend*>(backend())->precisionMode());
auto lastS = cmd->size()->data()[2];
if (lastS == 1 || cmd->view()->GetAs<View>(1)->stride()->data()[2] == 1) {
for (int z=0; z<cmd->size()->data()[0]; ++z) {
auto srcZ = src + z * cmd->view()->GetAs<View>(1)->stride()->data()[0] * bytes;
auto dstZ = dst + z * outputStride[0] * bytes;
for (int y=0; y<cmd->size()->data()[1]; ++y) {
auto srcY = srcZ + y * cmd->view()->GetAs<View>(1)->stride()->data()[1] * bytes;
auto dstY = dstZ + y * outputStride[1] * bytes;
proc(dstY, srcY, lastS);
}
}
} else {
// Blit to cache
auto srcCache = mCacheBuffer.ptr() + mMaxCacheSize * tId;
for (int z=0; z<cmd->size()->data()[0]; ++z) {
auto srcZ = src + z * cmd->view()->GetAs<View>(1)->stride()->data()[0] * bytes;
auto dstZ = dst + z * outputStride[0] * bytes;
for (int y=0; y<cmd->size()->data()[1]; ++y) {
auto srcY = srcZ + y * cmd->view()->GetAs<View>(1)->stride()->data()[1] * bytes;
auto dstY = dstZ + y * outputStride[1] * bytes;
blit(srcCache, srcY, lastS, cmd->view()->GetAs<View>(1)->stride()->data()[2], 1);
proc(dstY, srcCache, lastS);
}
}
}
continue;
}
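// MatMul: pointers were resolved above; run the prepared CPUMatMul execution on A, B
// and the optional bias for this iteration.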
if (OpType_MatMul == op->type()) {
// TODO: Fuse is not handled inside MatMul itself; when requested, the fused binary op is applied in the epilogue below
const float* APtr = nullptr;
const float* BPtr = nullptr;
const float* BiasPtr = nullptr;
float* CPtr = (float*)dst;
auto exe = static_cast<CPUMatMul*>(mContainer[tId].exe[index].get());
APtr = (const float*)mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
BPtr = (const float*)mContainer[tId].stackPtr[cmd->indexes()->data()[2]];
if (iterIndexsize > 3) {
BiasPtr = (const float*)mContainer[tId].stackPtr[cmd->indexes()->data()[3]];
}
exe->execute(APtr, BPtr, CPtr, BiasPtr);
break;
}
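// BinaryOp: pick a float or int kernel, then apply it row by row, staging both inputs
// through the cache buffer when their innermost strides are not 1.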
if (OpType_BinaryOp == op->type()) {
auto src0 = mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
MNNBinaryExecute proc;
if (inputType.code == halide_type_float) {
proc = static_cast<CPUBackend*>(backend())->functions()->MNNSelectBinaryFunctionForFloat(op->main_as_BinaryOp()->opType());
} else {
MNN_ASSERT(inputType.code == halide_type_int);
proc = CPUBinary::selectForInt(op->main_as_BinaryOp()->opType());
}
auto lastS = cmd->size()->data()[2];
auto stride0 = outputStride;
auto stride1 = cmd->view()->GetAs<View>(1)->stride()->data();
MNN_ASSERT(stride0[2] == 1);
auto src1 = mContainer[tId].stackPtr[cmd->indexes()->data()[2]];
auto stride2 = cmd->view()->GetAs<View>(2)->stride()->data();
auto blit1 = _selectUnitProc(bytes, stride1[2], 1);
auto blit2 = _selectUnitProc(bytes, stride2[2], 1);
if (cmd->size()->data()[2] == 1 || (stride1[2] == 1 && stride2[2] == 1)) {
for (int z=0; z<cmd->size()->data()[0]; ++z) {
auto src0Z = src0 + z * stride1[0] * bytes;
auto src1Z = src1 + z * stride2[0] * bytes;
auto dstZ = dst + z * stride0[0] * bytes;
for (int y=0; y<cmd->size()->data()[1]; ++y) {
auto src0Y = src0Z + y * stride1[1] * bytes;
auto src1Y = src1Z + y * stride2[1] * bytes;
auto dstY = dstZ + y * stride0[1] * bytes;
proc(dstY, src0Y, src1Y, cmd->size()->data()[2], -1);
}
}
} else {
auto cache0 = mCacheBuffer.ptr() + mMaxCacheSize * tId;
auto cache1 = cache0 + cmd->size()->data()[2] * bytes;
for (int z=0; z<cmd->size()->data()[0]; ++z) {
auto src0Z = src0 + z * stride1[0] * bytes;
auto src1Z = src1 + z * stride2[0] * bytes;
auto dstZ = dst + z * stride0[0] * bytes;
for (int y=0; y<cmd->size()->data()[1]; ++y) {
auto src0Y = src0Z + y * stride1[1] * bytes;
auto src1Y = src1Z + y * stride2[1] * bytes;
auto dstY = dstZ + y * stride0[1] * bytes;
blit1(cache0, src0Y, cmd->size()->data()[2], stride1[2], 1);
blit2(cache1, src1Y, cmd->size()->data()[2], stride2[2], 1);
proc(dstY, cache0, cache1, cmd->size()->data()[2], -1);
}
}
}
break;
}
} while(false);
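// Fuse epilogue: when the command wrote into the temporary fuse buffer, combine it into
// the real output with the requested binary op.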
if (dst != dstOrigin) {
MNN_ASSERT(bytes == 4);
// Only float32 is supported here; MatMul applies the fused op over the whole buffer, other ops handle ADD / MUL / SUB below
auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
auto srcF = (const float*)dst;
auto dstF = (float*)dstOrigin;
int sizeZ = cmd->size()->data()[0];
int sizeY = cmd->size()->data()[1];
int sizeX = cmd->size()->data()[2];
if (cmd->op()->type() == OpType_MatMul) {
auto proc = static_cast<CPUBackend*>(backend())->functions()->MNNSelectBinaryFunctionForFloat(cmd->fuse());
proc(dstF, dstF, srcF, sizeZ * sizeX, -1);
continue;
}
switch (cmd->fuse()) {
case BinaryOpOperation_ADD:
for (int z=0; z<sizeZ; ++z) {
auto srcZ = srcF + z * outputStride[0];
auto dstZ = dstF + z * dstStride[0];
for (int y=0; y<sizeY; ++y) {
auto srcY = srcZ + y * outputStride[1];
auto dstY = dstZ + y * dstStride[1];
for (int x=0; x<sizeX; ++x) {
auto dstOffset = x * dstStride[2];
dstY[dstOffset] = dstY[dstOffset] + srcY[x];
}
}
}
break;
case BinaryOpOperation_MUL:
for (int z=0; z<sizeZ; ++z) {
auto srcZ = srcF + z * outputStride[0];
auto dstZ = dstF + z * dstStride[0];
for (int y=0; y<sizeY; ++y) {
auto srcY = srcZ + y * outputStride[1];
auto dstY = dstZ + y * dstStride[1];
for (int x=0; x<sizeX; ++x) {
auto dstOffset = x * dstStride[2];
dstY[dstOffset] = dstY[dstOffset] * srcY[x];
}
}
}
break;
case BinaryOpOperation_SUB:
for (int z=0; z<sizeZ; ++z) {
auto srcZ = srcF + z * outputStride[0];
auto dstZ = dstF + z * dstStride[0];
for (int y=0; y<sizeY; ++y) {
auto srcY = srcZ + y * outputStride[1];
auto dstY = dstZ + y * dstStride[1];
for (int x=0; x<sizeX; ++x) {
auto dstOffset = x * dstStride[2];
dstY[dstOffset] = dstY[dstOffset] - srcY[x];
}
}
}
break;
default:
break;
}
}
}
};
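// Schedule the loop: when parallel, iterations are distributed round-robin across
// threads; otherwise run them sequentially on a single thread.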
if (mLoop->parallel()) {
MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
for (int iter=tId; iter < mLoop->loopNumber(); iter+=threadNumber) {
func(iter, tId);
}
}
MNN_CONCURRENCY_END();
} else {
for (int iter=0; iter < mLoop->loopNumber(); ++iter) {
func(iter, 0);
}
}
return NO_ERROR;
}