virtual ErrorCode onExecute()

in source/backend/cpu/CPURaster.cpp [818:1130]


    /**
     * Execute the loop described by mLoop.
     *
     * Steps, in order:
     *  1. Run every init command: a command without an op zero-fills its output tensor,
     *     any other init command is one strided copy (_blit) from its input to its output.
     *  2. Fast path: if the loop consists of exactly one un-fused copy command (UnaryOp with
     *     no main, "Gather / Single Unary"), run it serially with bounds-checked offsets and
     *     return.
     *  3. General path: for each iteration run every command (copy / unary / matmul / binary).
     *     A command with fuse() >= 0 writes into the fuse buffer slice selected by the thread
     *     id and the result is then folded into the real output with the fuse operation.
     *
     * @param originInputs  input tensors; the fast path reads its iteration index tensors from it.
     * @param originOutputs output tensors; not referenced directly, tensors are reached via mStack.
     * @return NO_ERROR on every path.
     */
    virtual ErrorCode onExecute(const std::vector<Tensor *> &originInputs, const std::vector<Tensor *> &originOutputs) override {
        auto cpubackend = static_cast<CPUBackend*>(backend());
        auto precision = cpubackend->precisionMode();
        auto threadNumber = cpubackend->threadNumber();
        if (mLoop->initCommand() != nullptr) {
            for (int i=0; i<mLoop->initCommand()->size(); ++i) {
                auto cmd = mLoop->initCommand()->GetAs<RegionCommand>(i);
                if (cmd->op() == nullptr) {
                    // No op attached: clear the whole output tensor.
                    auto output = mStack[cmd->indexes()->data()[0]];
                    ::memset(output->host<void>(), 0, cpubackend->getTensorSize(output) * cpubackend->functions()->bytes);
                } else {
                    // Otherwise the init command is a plain strided copy input -> output.
                    Tensor::InsideDescribe::Region reg;
                    auto srcView = cmd->view()->GetAs<View>(1);
                    auto dstView = cmd->view()->GetAs<View>(0);
                    ::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
                    ::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
                    ::memcpy(reg.dst.stride, dstView->stride()->data(), 3 * sizeof(int32_t));
                    auto input = mStack[cmd->indexes()->data()[1]];
                    auto output = mStack[cmd->indexes()->data()[0]];
                    auto bytes = input->getType().bytes();
                    if (halide_type_float == input->getType().code) {
                        // Float tensors use the backend's float width rather than the type's own.
                        bytes = cpubackend->functions()->bytes;
                    }
                    _blit(reg, bytes, input->host<uint8_t>(), output->host<uint8_t>(), false);
                }
            }
        }
        if (1 == mLoop->commands()->size()) {
            auto cmd = mLoop->commands()->GetAs<RegionCommand>(0);
            auto op = cmd->op();
            if (OpType_UnaryOp == op->type() && nullptr == op->main() && cmd->fuse() < 0) {
                // For Gather / Single Unary
                auto index0 = cmd->iterIndexes()->data()[0];
                auto index1 = cmd->iterIndexes()->data()[1];
                // Without an index tensor the iteration counter itself is the index: point at
                // `iter` with stride 0 so the lookup below always reads the current counter.
                int32_t iter = 0;
                int32_t* iter0 = &iter;
                int32_t* iter1 = &iter;
                int32_t iter0Stride = 0;
                int32_t iter1Stride = 0;
                if (index0 >= 0) {
                    iter0 = originInputs[index0]->host<int32_t>();
                    iter0Stride = 1;
                }
                if (index1 >= 0) {
                    iter1 = originInputs[index1]->host<int32_t>();
                    iter1Stride = 1;
                }
                Tensor::InsideDescribe::Region reg;
                auto srcView = cmd->view()->GetAs<View>(1);
                auto dstView = cmd->view()->GetAs<View>(0);
                ::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
                ::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
                ::memcpy(reg.dst.stride, dstView->stride()->data(), 3 * sizeof(int32_t));
                auto input = mStack[cmd->indexes()->data()[1]];
                auto inputSize = input->usize() / input->buffer().type.bytes();
                auto output = mStack[cmd->indexes()->data()[0]];
                auto outputSize = output->usize() / output->buffer().type.bytes();
                auto bytes = input->getType().bytes();
                if (halide_type_float == input->getType().code) {
                    bytes = cpubackend->functions()->bytes;
                }
                auto step0 = cmd->steps()->data()[0];
                auto step1 = cmd->steps()->data()[1];
                auto loopNumber = mLoop->loopNumber();
                for (; iter<loopNumber; ++iter) {
                    auto srcIter = *(iter1 + iter1Stride * iter);
                    auto dstIter = *(iter0 + iter0Stride * iter);
                    auto srcOffset = srcIter * step1 + srcView->offset();
                    auto dstOffset = dstIter * step0 + dstView->offset();
                    // Destination offsets outside the output are skipped; source offsets outside
                    // the input produce zeros in the destination region instead of a copy.
                    if (dstOffset >= 0 && dstOffset < outputSize) {
                        if (srcOffset >= 0 && srcOffset < inputSize) {
                            _blit(reg, bytes, input->host<uint8_t>() + bytes * srcOffset, output->host<uint8_t>() + bytes * dstOffset, false);
                        } else {
                            _zero(reg, bytes, output->host<uint8_t>() + bytes * dstOffset);
                        }
                    }
                }
                return NO_ERROR;
            }
        }
        auto bytes = cpubackend->functions()->bytes;
        // Runs all commands of one loop iteration `iter` using the scratch slices of thread `tId`.
        auto func = [&](int iter, int tId) {
            int fuseOutputStride[3];
            const int32_t* outputStride = nullptr;
            auto fuseBuffer = mFuseBuffer + mMaxFuseBufferSize * tId;
            for (int index=0; index<mLoop->commands()->size(); ++index) {
                auto cmd = mLoop->commands()->GetAs<RegionCommand>(index);
                auto blit = _selectUnitProc(bytes, cmd->view()->GetAs<View>(1)->stride()->data()[2], 1);
                auto op = cmd->op();
                int iterIndexsize = cmd->iterIndexes()->size();

                if (cmd->fuse() >= 0) {
                    // Fused commands write a dense [size0, size1, size2] block into the fuse buffer.
                    outputStride = fuseOutputStride;
                    auto cmdSize = cmd->size()->data();
                    fuseOutputStride[0] = cmdSize[1] * cmdSize[2];
                    fuseOutputStride[1] = cmdSize[2];
                    fuseOutputStride[2] = 1;
                } else {
                    // Loop Op's command's first index must be output
                    outputStride = cmd->view()->GetAs<View>(0)->stride()->data();
                }
                halide_type_t inputType;
                // Resolve the start pointer of every tensor used by this command for this iteration.
                for (int v=0; v<iterIndexsize; ++v) {
                    auto tensorIndex = cmd->indexes()->data()[v];
                    auto tensor = mStack[tensorIndex];
                    auto iterIndex = cmd->iterIndexes()->data()[v];
                    auto offset = iter;
                    if (1 == v) {
                        // The first input decides between the float and int binary kernels below.
                        inputType = tensor->getType();
                    }
                    if (iterIndex >= 0) {
                        offset = mStack[iterIndex]->host<int32_t>()[iter];
                    }
                    auto view = cmd->view()->GetAs<View>(v);
                    offset = offset * cmd->steps()->data()[v] + view->offset();
                    mContainer[tId].stackPtr[tensorIndex] = tensor->host<uint8_t>() + offset * bytes;
                    MNN_ASSERT(nullptr != tensor->host<uint8_t>());
                }
                auto dstOrigin = (uint8_t*)mContainer[tId].stackPtr[cmd->indexes()->data()[0]];
                auto dst = dstOrigin;
                if (cmd->fuse() >= 0) {
                    dst = fuseBuffer.ptr();
                }
                // Single-pass block: every handled op type leaves it with `break`.
                do {
                    if (OpType_UnaryOp == op->type()) {
                        auto src = (uint8_t*)mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
                        if (nullptr == op->main()) {
                            // Copy
                            Tensor::InsideDescribe::Region reg;
                            auto srcView = cmd->view()->GetAs<View>(1);
                            ::memcpy(reg.size, cmd->size()->data(), 3 * sizeof(int32_t));
                            ::memcpy(reg.src.stride, srcView->stride()->data(), 3 * sizeof(int32_t));
                            ::memcpy(reg.dst.stride, outputStride, 3 * sizeof(int32_t));
                            _blit(reg, bytes, (const uint8_t*)src, (uint8_t*)dst, false);
                            break;
                        }
                        auto proc = cpubackend->functions()->MNNSelectUnaryFunctionForFloat(op->main_as_UnaryOp()->opType(), precision);
                        auto size = cmd->size()->data();
                        auto srcStride = cmd->view()->GetAs<View>(1)->stride()->data();
                        auto lastS = size[2];
                        if (lastS == 1 || srcStride[2] == 1) {
                            // Source rows are contiguous: apply the unary kernel row by row in place.
                            for (int z=0; z<size[0]; ++z) {
                                auto srcZ = src + z * srcStride[0] * bytes;
                                auto dstZ = dst + z * outputStride[0] * bytes;
                                for (int y=0; y<size[1]; ++y) {
                                    auto srcY = srcZ + y * srcStride[1] * bytes;
                                    auto dstY = dstZ + y * outputStride[1] * bytes;
                                    proc(dstY, srcY, lastS);
                                }
                            }
                        } else {
                            // Blit to cache
                            auto srcCache = mCacheBuffer.ptr() + mMaxCacheSize * tId;
                            for (int z=0; z<size[0]; ++z) {
                                auto srcZ = src + z * srcStride[0] * bytes;
                                auto dstZ = dst + z * outputStride[0] * bytes;
                                for (int y=0; y<size[1]; ++y) {
                                    auto srcY = srcZ + y * srcStride[1] * bytes;
                                    auto dstY = dstZ + y * outputStride[1] * bytes;
                                    blit(srcCache, srcY, lastS, srcStride[2], 1);
                                    proc(dstY, srcCache, lastS);
                                }
                            }
                        }
                        break;
                    }
                    if (OpType_MatMul == op->type()) {
                        // TODO: Don't support fuse for matmul currently
                        const float* APtr = nullptr;
                        const float* BPtr = nullptr;
                        const float* BiasPtr = nullptr;
                        float* CPtr = (float*)dst;
                        auto exe = static_cast<CPUMatMul*>(mContainer[tId].exe[index].get());
                        APtr = (const float*)mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
                        BPtr = (const float*)mContainer[tId].stackPtr[cmd->indexes()->data()[2]];
                        if (iterIndexsize > 3) {
                            BiasPtr = (const float*)mContainer[tId].stackPtr[cmd->indexes()->data()[3]];
                        }
                        exe->execute(APtr, BPtr, CPtr, BiasPtr);
                        break;
                    }
                    if (OpType_BinaryOp == op->type()) {
                        auto src0 = mContainer[tId].stackPtr[cmd->indexes()->data()[1]];
                        MNNBinaryExecute proc;
                        if (inputType.code == halide_type_float) {
                            proc = cpubackend->functions()->MNNSelectBinaryFunctionForFloat(op->main_as_BinaryOp()->opType());
                        } else {
                            MNN_ASSERT(inputType.code == halide_type_int);
                            proc = CPUBinary::selectForInt(op->main_as_BinaryOp()->opType());
                        }
                        auto size = cmd->size()->data();
                        auto lastS = size[2];
                        auto stride0 = outputStride;
                        auto stride1 = cmd->view()->GetAs<View>(1)->stride()->data();
                        MNN_ASSERT(stride0[2] == 1);
                        auto src1 = mContainer[tId].stackPtr[cmd->indexes()->data()[2]];
                        auto stride2 = cmd->view()->GetAs<View>(2)->stride()->data();
                        auto blit1   = _selectUnitProc(bytes, stride1[2], 1);
                        auto blit2   = _selectUnitProc(bytes, stride2[2], 1);
                        if (lastS == 1 || (stride1[2] == 1 && stride2[2] == 1)) {
                            // Both inputs have contiguous rows: feed them to the kernel directly.
                            for (int z=0; z<size[0]; ++z) {
                                auto src0Z = src0 + z * stride1[0] * bytes;
                                auto src1Z = src1 + z * stride2[0] * bytes;
                                auto dstZ = dst + z * stride0[0] * bytes;
                                for (int y=0; y<size[1]; ++y) {
                                    auto src0Y = src0Z + y * stride1[1] * bytes;
                                    auto src1Y = src1Z + y * stride2[1] * bytes;
                                    auto dstY = dstZ + y * stride0[1] * bytes;
                                    proc(dstY, src0Y, src1Y, lastS, -1);
                                }
                            }
                        } else {
                            // Gather each strided input row into the cache slice first.
                            auto cache0 = mCacheBuffer.ptr() + mMaxCacheSize * tId;
                            auto cache1 = cache0 + lastS * bytes;
                            for (int z=0; z<size[0]; ++z) {
                                auto src0Z = src0 + z * stride1[0] * bytes;
                                auto src1Z = src1 + z * stride2[0] * bytes;
                                auto dstZ = dst + z * stride0[0] * bytes;
                                for (int y=0; y<size[1]; ++y) {
                                    auto src0Y = src0Z + y * stride1[1] * bytes;
                                    auto src1Y = src1Z + y * stride2[1] * bytes;
                                    auto dstY = dstZ + y * stride0[1] * bytes;
                                    blit1(cache0, src0Y, lastS, stride1[2], 1);
                                    blit2(cache1, src1Y, lastS, stride2[2], 1);
                                    proc(dstY, cache0, cache1, lastS, -1);
                                }
                            }
                        }
                        break;
                    }
                } while(false);
                if (dst != dstOrigin) {
                    // The command wrote into the fuse buffer: fold that result into the real output.
                    MNN_ASSERT(bytes == 4);
                    // Currently only support add and float32
                    auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
                    auto srcF = (const float*)dst;
                    auto dstF = (float*)dstOrigin;
                    int sizeZ = cmd->size()->data()[0];
                    int sizeY = cmd->size()->data()[1];
                    int sizeX = cmd->size()->data()[2];
                    if (cmd->op()->type() == OpType_MatMul) {
                        auto proc = cpubackend->functions()->MNNSelectBinaryFunctionForFloat(cmd->fuse());
                        proc(dstF, dstF, srcF, sizeZ * sizeX, -1);
                        continue;
                    }
                    // srcF is the fuse buffer, laid out with outputStride (the dense fuse stride);
                    // dstF is the real output, laid out with the output view's stride. Every fuse
                    // operation must use this same addressing and advance rows with y.
                    const auto fuseType = cmd->fuse();
                    for (int z=0; z<sizeZ; ++z) {
                        auto srcZ = srcF + z * outputStride[0];
                        auto dstZ = dstF + z * dstStride[0];
                        for (int y=0; y<sizeY; ++y) {
                            auto srcY = srcZ + y * outputStride[1];
                            auto dstY = dstZ + y * dstStride[1];
                            switch (fuseType) {
                                case BinaryOpOperation_ADD:
                                    for (int x=0; x<sizeX; ++x) {
                                        auto dstOffset = x * dstStride[2];
                                        dstY[dstOffset] = dstY[dstOffset] + srcY[x];
                                    }
                                    break;
                                case BinaryOpOperation_MUL:
                                    for (int x=0; x<sizeX; ++x) {
                                        auto dstOffset = x * dstStride[2];
                                        dstY[dstOffset] = dstY[dstOffset] * srcY[x];
                                    }
                                    break;
                                case BinaryOpOperation_SUB:
                                    for (int x=0; x<sizeX; ++x) {
                                        auto dstOffset = x * dstStride[2];
                                        dstY[dstOffset] = dstY[dstOffset] - srcY[x];
                                    }
                                    break;
                                default:
                                    // Other fuse operations are not handled here; output is left unchanged.
                                    break;
                            }
                        }
                    }
                }
            }
        };
        if (mLoop->parallel()) {
            // Iterations are distributed round-robin: thread tId handles tId, tId + threadNumber, ...
            MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
                for (int iter=tId; iter < mLoop->loopNumber(); iter+=threadNumber) {
                    func(iter, tId);
                }
            }
            MNN_CONCURRENCY_END();
        } else {
            for (int iter=0; iter < mLoop->loopNumber(); ++iter) {
                func(iter, 0);
            }
        }
        return NO_ERROR;
    }