ErrorCode GeometryComputerUtils::shapeComputeAndGeometryTransform()

in source/geometry/GeometryComputerUtils.cpp [142:440]

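This routine drives resize-time graph compilation in two passes. The first pass recomputes output shapes for every op (unless skipShapeCompute is set) and folds Schedule::CONSTANT ops immediately, running them on the backup (CPU) backend; per-op caches let both shape inference and constant execution be skipped when inputs have not changed. The second pass applies the geometry transform, decomposing the remaining ops into primitive commands inside the caller's GeometryComputer::Context. When built with MNN_BUILD_CODEGEN and permitCodegen is set, the resulting command streams are additionally fused via opFuse.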

ErrorCode GeometryComputerUtils::shapeComputeAndGeometryTransform(
    const Runtime* cpuRuntime,
    FileLoader* external,
    std::vector<Schedule::OpCacheInfo>& infos,
    GeometryComputer::Context& geoContext,
    std::shared_ptr<Backend> backupBackend,
    Runtime::CompilerType compileType, 
    bool skipShapeCompute,
    bool permitCodegen) {
    bool openCache = geoContext.support(Interpreter::GeometryComputeMask::GEOMETRCOMPUTEMASK_OPENCACHE);
    /** Size Compute and compute Const Begin */
    GeometryComputer::Context ctx(Interpreter::GeometryComputeMask::GEOMETRCOMPUTEMASK_ALL, backupBackend);
    bool needRelease = geoContext.mNeedRelease;
    // Size Compute and compute Const
    for (int i=0; i<infos.size(); ++i) {
        auto& info = infos[i];
        auto& cmdBufferVir = info.executeBuffer;
        auto& tempBuffer = info.cacheBuffer;
        // TODO: Optimize
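        // Reset mutable outputs' memory / backend bindings left over from a
        // previous resize so this pass starts from a clean descriptor state.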
        for (auto t : info.outputs) {
            if (!TensorUtils::getDescribe(t)->isMutable) {
                continue;
            }
            auto des = TensorUtils::getDescribe(t);
            auto usage = des->usage;
            auto type = des->memoryType;
            MNN_ASSERT(type != Tensor::InsideDescribe::MEMORY_OUTSIDE);
            MNN_ASSERT(type != Tensor::InsideDescribe::MEMORY_HOST);
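            // The descriptor is shared with another tensor: detach it by
            // allocating a fresh NativeInsideDescribe so this op's resize
            // cannot clobber the other reference, and close the compute cache.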
            if (TensorUtils::getDescribeOrigin(t)->mContent.use_count() > 1) {
                TensorUtils::getDescribeOrigin(t)->mContent.reset(new Tensor::InsideDescribe::NativeInsideDescribe);
                t->buffer().dim = TensorUtils::getDescribe(t)->dims;
                TensorUtils::getDescribeOrigin(t)->setBackend(nullptr);
                TensorUtils::getDescribeOrigin(t)->mem = nullptr;
                TensorUtils::getDescribe(t)->usage = usage;
                info.computeCache.close();
            } else if (des->group == 0) {
                if (info.type != Schedule::CONSTANT && usage != Tensor::InsideDescribe::TRAINABLE) {
                    TensorUtils::getDescribeOrigin(t)->setBackend(nullptr);
                    // TODO: If output is static and length larger than new size, don't clear mem
                    TensorUtils::getDescribeOrigin(t)->mem = nullptr;
                }
            }
        }
        for (auto t : info.outputs) {
            TensorUtils::getDescribe(t)->stageMask &= (~Tensor::InsideDescribe::StageInfo::COMPUTE_SHAPE_STAGE);
        }
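        // Consult the per-op cache: if the inputs still match the cached ones,
        // both shape inference and geometry compute can be skipped. 'compared'
        // reports whether a real comparison took place.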
        bool compared = false;
        bool needCompute = !info.computeCache.match(info.inputs, compared);
        if (needCompute && compared) {
            // The inputs were compared but did not match, so the op's shape is mutable across runs; close the cache so it is not compared again
            info.computeCache.close(false);
        }
        if ((!skipShapeCompute) && needCompute) {
            auto res = SizeComputer::computeOutputSize(info.op, info.inputs, info.outputs);
            if (!res) {
                if (info.op->name() != nullptr) {
                    MNN_ERROR("Compute Shape Error for %s\n", info.op->name()->c_str());
                } else {
                    MNN_ERROR("Compute Shape Error for %d\n", info.op->type());
                }
                return COMPUTE_SIZE_ERROR;
            }
            // FIXME: Find a better way to keep compatibility with old models
            /**
             For Convolution on a 2D / 3D Tensor (Dense / 1D Convolution),
             old code accesses dim[2] / dim[3] to get width and height.
             Set those lengths to 1 for compatibility.
             */
            for (auto t : info.outputs) {
                TensorUtils::adjustTensorForCompability(t);
            }
            for (auto t: info.inputs) {
                TensorUtils::adjustTensorForCompability(t);
            }
            info.computeCache.insert(info.inputs);
            for (auto t : info.outputs) {
                TensorUtils::getDescribe(t)->rasterCommand.reset();
                TensorUtils::getDescribe(t)->stageMask |= Tensor::InsideDescribe::StageInfo::COMPUTE_SHAPE_STAGE;
                // The content may be produced by the geometry computer without a real
                // execution, so it must not be treated as unchanged
                TensorUtils::getDescribe(t)->stageMask &= (~Tensor::InsideDescribe::StageInfo::CONTENT_NOT_CHANGE);
            }
        }
        info.computeCache.needComputeShape = needCompute;
        if (info.type != Schedule::CONSTANT) {
            continue;
        }
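        // From here on this pass handles only Schedule::CONSTANT ops: run the
        // geometry computer now and execute the generated commands on the
        // backup backend so constants are folded at resize time.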
        if (!needCompute) {
            for (auto t : info.outputs) {
                TensorUtils::getDescribe(t)->stageMask |= Tensor::InsideDescribe::StageInfo::CONTENT_NOT_CHANGE;
            }
        }
        if (_hasZeroShapeOutput(info)) {
            continue;
        }
        // Skip geometry compute when shape recomputation was not needed
        if (needCompute) {
            cmdBufferVir.command.clear();
            cmdBufferVir.extras.clear();
            
            ctx.clear();
            auto geo = GeometryComputer::search(info.op->type(), Runtime::Compiler_Loop);
            {
                bool res = false;
                if (openCache) {
                    res = geo->onRecompute(info.op, info.inputs, info.outputs, geoContext, tempBuffer);
                }
                if (!res) {
                    tempBuffer.command.clear();
                    tempBuffer.extras.clear();
                    res = geo->onCompute(info.op, info.inputs, info.outputs, geoContext, tempBuffer);
                }
                if (!res) {
                    if (info.op->name() != nullptr) {
                        MNN_ERROR("Const Folder Error in geometry for %s\n", info.op->name()->c_str());
                    } else {
                        MNN_ERROR("Const Folder Error in geometry for %d\n", info.op->type());
                    }
                    return NOT_SUPPORT;
                }
            }
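            // Lower the virtual commands into executable ones and materialize
            // every raster region the outputs transitively depend on.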
            GeometryComputerUtils::makeRaster(tempBuffer, cmdBufferVir, ctx);
            for (auto t : info.outputs) {
                ctx.getRasterCacheCreateRecursive(t, cmdBufferVir);
                if (Tensor::InsideDescribe::MEMORY_VIRTUAL == TensorUtils::getDescribe(t)->memoryType) {
                    TensorUtils::getDescribe(t)->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND;
                }
            }
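            // For each generated command, reuse a cached execution if present,
            // otherwise create one (loading external weights when needed),
            // then allocate static output buffers and resize the execution.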
            for (auto& cp : cmdBufferVir.command) {
                auto& c = *cp;
                std::shared_ptr<BufferStorage> tmpStorage;
                if (nullptr == c.execution) {
                    auto opIter = info.executionCache.find(c.op);
                    if (opIter != info.executionCache.end()) {
                        c.execution = opIter->second;
                    } else {
                        auto exe = OpCommonUtils::createExecutionWithExternal(backupBackend.get(), c.inputs, c.outputs, c.op, external, tmpStorage);
                        c.execution.reset(exe);
                    }
                }
                auto exe = c.execution;
                if (nullptr == exe.get()) {
                    if (info.op->name() != nullptr) {
                        MNN_ERROR("Const Folder Error for %s\n", info.op->name()->c_str());
                    } else {
                        MNN_ERROR("Const Folder Error for %d\n", info.op->type());
                    }
                    return NO_EXECUTION;
                }
                backupBackend->onResizeBegin();
                for (auto t : c.outputs) {
                    auto des = TensorUtils::getDescribeOrigin(t);
                    TensorUtils::setLinearLayout(t);
                    auto res = backupBackend->onAcquireBuffer(t, Backend::STATIC);
                    if (!res) {
                        return OUT_OF_MEMORY;
                    }
                    des->setBackend(backupBackend.get());
                }
                auto code = exe->onResize(c.inputs, c.outputs);
                if (NO_ERROR != code) {
                    return NOT_SUPPORT;
                }
                code = backupBackend->onResizeEnd();
                if (NO_ERROR != code) {
                    return NOT_SUPPORT;
                }
            }
        }
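        // Decide whether the folded constants actually need re-execution.
        // Random ops are always dirty; otherwise a command is dirty when the
        // shape was recomputed or any mutable input's content has changed.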
        for (auto& cp : cmdBufferVir.command) {
            auto& c = *cp;
            bool dirty = needCompute || c.op->type() == OpType_RandomNormal || c.op->type() == OpType_RandomUniform;
            if (!dirty) {
                for (auto t : c.inputs) {
                    auto des = TensorUtils::getDescribe(t);
                    if (!des->isMutable) {
                        continue;
                    }
                    if (des->group < 0) {
                        // group = -1 marks a tensor fed by user input, so its content may change every run
                        dirty = true;
                        break;
                    }
                    if ((des->stageMask & Tensor::InsideDescribe::StageInfo::CONTENT_NOT_CHANGE) == 0) {
                        dirty = true;
                        break;
                    }
                }
            }
            info.computeCache.needExecuteConst = dirty;
            if (dirty) {
                backupBackend->onExecuteBegin();
                if (cpuRuntime->pCurrentStatus != NO_ERROR) {
                    return (ErrorCode)cpuRuntime->pCurrentStatus;
                }
                auto code = cp->execution->onExecute(c.inputs, c.outputs);
                if (NO_ERROR != code) {
                    return NOT_SUPPORT;
                }
                backupBackend->onExecuteEnd();

                for (auto t : c.outputs) {
                    TensorUtils::getDescribe(t)->stageMask &= (~Tensor::InsideDescribe::StageInfo::CONTENT_NOT_CHANGE);
                }
            } else {
                for (auto t : c.outputs) {
                    TensorUtils::getDescribe(t)->stageMask |= Tensor::InsideDescribe::StageInfo::CONTENT_NOT_CHANGE;
                }
            }
        }
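        // When the context allows releasing, drop the command buffers and free
        // inputs marked releasable to reduce peak memory during resize.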
        if (needRelease) {
            cmdBufferVir.command.clear();
            cmdBufferVir.extras.clear();
            
            ctx.clear();
            for (auto index : info.releaseAbleInputs) {
                TensorUtils::getDescribeOrigin(info.inputs[index])->mem = nullptr;
            }
        }
    }

    /** Size Compute and compute Const End */

    /** Geometry Transform */
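    // Second pass: decompose each non-constant op into primitive commands
    // using the caller's context. Ops whose shapes were not recomputed reuse
    // their cached command buffer unless a format wrap invalidated it.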
    for (int i=0; i<infos.size(); ++i) {
        auto& info = infos[i];
        auto& cmdBufferReal = info.executeBuffer;
        auto& tempBuffer = info.cacheBuffer;
        // TODO: Optimize
        if (info.type == Schedule::CONSTANT) {
            continue;
        }
        if ((!info.computeCache.needComputeShape) && (!tempBuffer.hasWrap)) {
            continue;
        }
        cmdBufferReal.command.clear();
        cmdBufferReal.extras.clear();
        if (_hasZeroShapeOutput(info)) {
            continue;
        }
        auto geo = GeometryComputer::search(info.op->type(), compileType);
        {
            bool res = false;
            if ((!tempBuffer.hasWrap) && openCache) {
                res = geo->onRecompute(info.op, info.inputs, info.outputs, geoContext, tempBuffer);
            }
            if (!res) {
                tempBuffer.command.clear();
                tempBuffer.extras.clear();
                res = geo->onCompute(info.op, info.inputs, info.outputs, geoContext, tempBuffer);
            }
            if (!res) {
                return NOT_SUPPORT;
            }
            tempBuffer.hasWrap = false;
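            // Convert the virtual commands into real ones before deciding
            // which outputs must be materialized immediately.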
            GeometryComputerUtils::makeRaster(tempBuffer, cmdBufferReal, geoContext);
            for (int v=0; v<info.outputs.size(); ++v) {
                auto t = info.outputs[v];
                auto des = TensorUtils::getDescribe(t);
                if (des->usage == Tensor::InsideDescribe::OUTPUT || des->usage == Tensor::InsideDescribe::TRAINABLE) {
                    // Output and trainable tensors must be computed directly rather than left as virtual raster views
                    geoContext.getRasterCacheCreateRecursive(t, cmdBufferReal);
                    if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
                        des->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND;
                    }
                }
            }
        }
    }

    
#ifdef MNN_BUILD_CODEGEN
    if(permitCodegen) {
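        // Fuse the generated command streams (e.g. chains of elementwise ops)
        // through the codegen path before execution.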
        #ifdef LOG_VERPOSE
        MNN_PRINT("infos : [\n");
        for (auto info : infos) {
            auto& cmds = info.executeBuffer.command;
            for (auto cmd : cmds) {
                MNN_PRINT("\t%s", EnumNameOpType(cmd->op->type()));
                if(cmd->op->type() == OpType_BinaryOp) {
                    MNN_PRINT(" %d ", cmd->op->main_as_BinaryOp()->opType());
                }
                if(cmd->op->type() == OpType_UnaryOp) {
                    MNN_PRINT(" %d ", cmd->op->main_as_UnaryOp()->opType());
                }
                MNN_PRINT("\n");
            }
        }
        MNN_PRINT("]\n");
        MNN_PRINT("==================== opFuse ====================\n");
        #endif

        opFuse(infos, geoContext.forwardType(), geoContext.precisionType());

        #ifdef LOG_VERPOSE
        MNN_PRINT("infos : [\n");
        for (auto info : infos) {
            auto& cmds = info.executeBuffer.command;
            for (auto cmd : cmds) {
                MNN_PRINT("\t%s\n", EnumNameOpType(cmd->op->type()));
            }
        }
        MNN_PRINT("]\n");
        #endif
    }
#endif
    return NO_ERROR;
}
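
For context, the sketch below shows roughly how a caller might drive this utility during a session resize. It is illustrative only: resizeGraph is a hypothetical wrapper, the header path is inferred from this file's location, and cpuRuntime, loader, infos, and backupBackend are assumed to be prepared by MNN's scheduling code.

#include "geometry/GeometryComputerUtils.hpp"

// Hypothetical helper, not part of MNN's public API.
static ErrorCode resizeGraph(const Runtime* cpuRuntime, FileLoader* loader,
                             std::vector<Schedule::OpCacheInfo>& infos,
                             std::shared_ptr<Backend> backupBackend) {
    // Enable every geometry-compute feature, including the command cache that
    // shapeComputeAndGeometryTransform checks via GEOMETRCOMPUTEMASK_OPENCACHE.
    GeometryComputer::Context geoContext(
        Interpreter::GeometryComputeMask::GEOMETRCOMPUTEMASK_ALL, backupBackend);
    return GeometryComputerUtils::shapeComputeAndGeometryTransform(
        cpuRuntime, loader, infos, geoContext, backupBackend,
        Runtime::Compiler_Loop, /*skipShapeCompute=*/false,
        /*permitCodegen=*/false);
}

Compiler_Loop mirrors the compiler type this function itself uses for constant folding; a real caller would forward whichever CompilerType the runtime selected.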