bool fuseConstIntoSubgraph()

in tools/converter/source/optimizer/PostConverter.cpp [361:586]


/**
 * Fold Const ops that feed a While op into that While's body and cond subgraphs.
 *
 * An input of a While op (one that has a cond_graph) is folded when all of the following hold:
 *   - it is produced by a Const op in `net`;
 *   - that Const output is read exactly once in the main graph (by this single input slot) and is
 *     not listed in net->outputName;
 *   - the input is bound to exactly one subgraph tensor name (aliases_inputs[i]->data.size() == 1);
 *   - that name never appears in aliases_updates;
 *   - neither the body nor the cond subgraph is referenced by more than one While/If op.
 * Folding turns the matching Input node inside body/cond into a Const node holding a copy of the
 * Const's Blob, drops the input from the subgraphs' input lists and from the While op, and removes
 * the outer Const op from net->oplists. Every modified subgraph is then re-run through
 * optimizeNetImpl so the new constants can be propagated.
 *
 * @param net       Main network; its oplists and While ops are modified in place.
 * @param subgraphs Subgraphs of the network, looked up by name from While/If parameters; those that
 *                  receive constants are modified in place.
 * @return true if at least one Const op was folded (net was modified), false otherwise.
 */
bool fuseConstIntoSubgraph(MNN::NetT* net, const std::vector<MNN::SubGraphProtoT*>& subgraphs) {
    if (subgraphs.empty()) {
        return false;
    }
    // Subgraph name -> (subgraph, number of While/If references). Shared subgraphs are never modified.
    std::map<std::string, std::pair<MNN::SubGraphProtoT*, int>> subGraphMaps;
    std::set<MNN::SubGraphProtoT*> modifiedSubGraph;
    for (auto s : subgraphs) {
        subGraphMaps.insert(std::make_pair(s->name, std::make_pair(s, 0)));
    }
    for (auto& op : net->oplists) {
        if (op->type == MNN::OpType_While) {
            auto param = op->main.AsWhileParam();
            subGraphMaps[param->body_graph].second++;
            subGraphMaps[param->cond_graph].second++;
        } else if (op->type == MNN::OpType_If) {
            auto param = op->main.AsIfParam();
            subGraphMaps[param->else_graph].second++;
            subGraphMaps[param->then_graph].second++;
        }
    }

    // For each tensor: index of the Const op producing it (-1 if none), and how many times it is
    // consumed. Every read is counted (an op reading the same tensor twice counts twice) and network
    // outputs count as an extra use, so a Const is folded only when one While input slot is its sole
    // user. Counting per op instead would fold a Const feeding two slots of the same While twice and
    // dereference its already-released Blob.
    const int tensorCount = static_cast<int>(net->tensorName.size());
    std::vector<int> constOpIndexes(tensorCount, -1);
    std::vector<int> useCount(tensorCount, 0);
    for (int i = 0; i < static_cast<int>(net->oplists.size()); ++i) {
        auto& op = net->oplists[i];
        if (op->type == MNN::OpType_Const) {
            constOpIndexes[op->outputIndexes[0]] = i;
        }
        for (auto inputIndex : op->inputIndexes) {
            if (inputIndex >= 0 && inputIndex < tensorCount) {
                useCount[inputIndex]++;
            }
        }
    }
    for (auto& outputName : net->outputName) {
        for (int t = 0; t < tensorCount; ++t) {
            if (net->tensorName[t] == outputName) {
                useCount[t]++;
            }
        }
    }

    // Turn the subgraph Input node whose tensor is named `inputName` into a Const node holding a copy
    // of `src`, and record that tensor index in `inputRemove`. No-op if no subgraph input has that name.
    auto mergeToSubGraph = [](MNN::SubGraphProtoT* subGraph, std::set<int>& inputRemove, const MNN::BlobT& src,
                              const std::string& inputName) {
        for (auto inputIndex : subGraph->inputs) {
            if (subGraph->tensors[inputIndex] != inputName) {
                continue;
            }
            inputRemove.insert(inputIndex);
            for (auto& subOp : subGraph->nodes) {
                if (subOp->type != MNN::OpType_Input || subOp->outputIndexes[0] != inputIndex) {
                    continue;
                }
                subOp->type = MNN::OpType_Const;
                // Free the previous parameter object first: Reset() deletes based on the current
                // `type`, so overwriting type/value directly would leak the old object.
                subOp->main.Reset();
                subOp->main.type      = MNN::OpParameter_Blob;
                subOp->main.value     = new MNN::BlobT;
                *subOp->main.AsBlob() = src;
                break;
            }
            break;
        }
    };
    // Rebuild subGraph->inputs without the tensor indexes listed in `inputRemove`.
    auto removeSubGraphInputs = [](MNN::SubGraphProtoT* subGraph, const std::set<int>& inputRemove) {
        auto originInput = std::move(subGraph->inputs);
        subGraph->inputs.clear();
        for (auto index : originInput) {
            if (inputRemove.find(index) == inputRemove.end()) {
                subGraph->inputs.emplace_back(index);
            }
        }
    };

    std::set<int> removeConstOpIndexes;
    for (auto& op : net->oplists) {
        if (op->type != MNN::OpType_While) {
            continue;
        }
        auto param = op->main.AsWhileParam();
        if (param->cond_graph.empty()) {
            // An empty cond_graph means the While comes from onnx's Loop.
            // TODO: Support Loop from onnx
            continue;
        }
        auto body = subGraphMaps[param->body_graph];
        auto cond = subGraphMaps[param->cond_graph];
        // A name absent from `subgraphs` maps to nullptr: nothing to merge into.
        if (nullptr == body.first || nullptr == cond.first) {
            continue;
        }
        // Don't optimize shared subgraphs: folding would also change their other users.
        if (body.second > 1 || cond.second > 1) {
            continue;
        }
        MNN_ASSERT(op->inputIndexes.size() == param->aliases_inputs.size());
        if (op->inputIndexes.size() != param->aliases_inputs.size()) {
            continue;
        }
        // True if `name` is written back by the loop body (listed in aliases_updates).
        auto isUpdated = [param](const std::string& name) -> bool {
            for (auto& update : param->aliases_updates) {
                for (auto& updateName : update->data) {
                    if (updateName == name) {
                        return true;
                    }
                }
            }
            return false;
        };

        std::set<int> removeInputs;
        std::set<int> bodyInputRemove;
        std::set<int> condInputRemove;
        for (int subI = 0; subI < static_cast<int>(op->inputIndexes.size()); ++subI) {
            auto index = op->inputIndexes[subI];
            if (index < 0 || index >= tensorCount) {
                continue;
            }
            auto constIndex = constOpIndexes[index];
            if (constIndex < 0) {
                continue;
            }
            // Don't support for graph shared input
            auto& alias = param->aliases_inputs[subI];
            if (nullptr == alias || alias->data.size() != 1) {
                continue;
            }
            auto inputName = alias->data[0];
            // Don't support for const init and update next
            if (isUpdated(inputName)) {
                continue;
            }
            // The const output must have no consumer other than this input slot.
            if (useCount[index] > 1) {
                continue;
            }
            auto& constOp = net->oplists[constIndex];
            MNN_ASSERT(constOp->main.type == MNN::OpParameter_Blob);
            auto blob = constOp->main.AsBlob();
            if (nullptr == blob) {
                continue;
            }
            mergeToSubGraph(body.first, bodyInputRemove, *blob, inputName);
            mergeToSubGraph(cond.first, condInputRemove, *blob, inputName);
            removeConstOpIndexes.insert(constIndex);
            removeInputs.insert(subI);
            modifiedSubGraph.insert(body.first);
            modifiedSubGraph.insert(cond.first);
            // The Blob has been copied into the subgraphs; release the outer copy now.
            constOp->main.Reset();
        }
        if (removeInputs.empty()) {
            continue;
        }
        removeSubGraphInputs(body.first, bodyInputRemove);
        removeSubGraphInputs(cond.first, condInputRemove);

        // Drop the folded inputs from the While op, keeping inputIndexes and aliases_inputs in step.
        auto originIndexes = std::move(op->inputIndexes);
        auto aliInputs     = std::move(param->aliases_inputs);
        op->inputIndexes.clear();
        param->aliases_inputs.clear();
        for (int subI = 0; subI < static_cast<int>(originIndexes.size()); ++subI) {
            if (removeInputs.find(subI) == removeInputs.end()) {
                op->inputIndexes.emplace_back(originIndexes[subI]);
                param->aliases_inputs.emplace_back(std::move(aliInputs[subI]));
            }
        }
    }
    if (removeConstOpIndexes.empty()) {
        return false;
    }
    auto originOpLists = std::move(net->oplists);
    net->oplists.clear();
    for (int i = 0; i < static_cast<int>(originOpLists.size()); ++i) {
        if (removeConstOpIndexes.find(i) == removeConstOpIndexes.end()) {
            net->oplists.emplace_back(std::move(originOpLists[i]));
        }
    }

    // Re-optimize every modified subgraph so the newly inlined constants can be folded further.
    auto* ctx = Global<OptimizeContext>::Get();
    std::unordered_map<std::string, VARP> empty;
    for (auto mutable_subgraph : modifiedSubGraph) {
        std::unique_ptr<MNN::NetT> subnet(new MNN::NetT);
        subnet->oplists    = std::move(mutable_subgraph->nodes);
        subnet->tensorName = std::move(mutable_subgraph->tensors);
        subnet->sourceType = ctx->source;
        std::vector<std::string> inputNames;
        std::vector<std::string> outputNames;
        for (auto v : mutable_subgraph->inputs) {
            inputNames.emplace_back(subnet->tensorName[v]);
        }
        for (auto v : mutable_subgraph->outputs) {
            outputNames.emplace_back(subnet->tensorName[v]);
        }
#ifdef MNN_POST_CONVERTER_DEBUG
        for (auto& v : outputNames) {
            FUNC_PRINT_ALL(v.c_str(), s);
        }
        FUNC_PRINT_ALL(mutable_subgraph->name.c_str(), s);
#endif
        subnet->outputName = outputNames;

        std::unique_ptr<MNN::NetT> new_subnet = optimizeNetImpl(subnet, empty);
        if (nullptr == new_subnet) {
            // NOTE(review): defensive; whether optimizeNetImpl can return null is not visible here.
            // Hand back whatever `subnet` still holds instead of dereferencing a null result.
            mutable_subgraph->nodes   = std::move(subnet->oplists);
            mutable_subgraph->tensors = std::move(subnet->tensorName);
            continue;
        }
        // Tensor indexes may change during optimization, so inputs/outputs are re-resolved by name.
        // An entry whose name is no longer present keeps its previous index.
        auto findTensor = [&new_subnet](const std::string& name, int fallback) -> int {
            for (int v = 0; v < static_cast<int>(new_subnet->tensorName.size()); ++v) {
                if (new_subnet->tensorName[v] == name) {
                    return v;
                }
            }
            return fallback;
        };
        for (int i = 0; i < static_cast<int>(inputNames.size()); ++i) {
            mutable_subgraph->inputs[i] = findTensor(inputNames[i], mutable_subgraph->inputs[i]);
        }
        for (int i = 0; i < static_cast<int>(outputNames.size()); ++i) {
            mutable_subgraph->outputs[i] = findTensor(outputNames[i], mutable_subgraph->outputs[i]);
        }
        mutable_subgraph->nodes   = std::move(new_subnet->oplists);
        mutable_subgraph->tensors = std::move(new_subnet->tensorName);
    }
    return true;
}