tools/converter/source/optimizer/PostTreatUtils.cpp (107 lines of code) (raw):
//
// PostTreatUtils.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "PostTreatUtils.hpp"
#include "OpCount.hpp"
#include <mutex>
#include <set>
using namespace MNN;
template <typename T>
bool inVector(const std::vector<T>& vec, const T& val) {
return std::find(vec.begin(), vec.end(), val) != vec.end();
}
std::map<std::string, std::shared_ptr<PostConverter>>* PostConverter::getConvertMap() {
static std::once_flag of;
static std::map<std::string, std::shared_ptr<PostConverter>>* gConverter;
std::call_once(of, [&]() {
gConverter = new std::map<std::string, std::shared_ptr<PostConverter>>;
auto count = MNN::OpCount::get();
count->insertOp("TF", "Dropout");
count->insertOp("TF", "NoOp");
count->insertOp("TF", "Print");
count->insertOp("CAFFE", "Dropout");
count->insertOp("CAFFE", "Split");
auto unuseExtraOpType = std::vector<std::string>({"Identity", "IdentityN", "NoOp", "Assign", "Print", "Assert", "StopGradient", "Enter", "NextIteration"});
for (auto& s : unuseExtraOpType) {
count->insertOp("TF", s);
}
std::set<std::string> controlOps{"Merge", "Switch", "LoopCond", "Enter", "Exit", "NextIteration"};
for (auto& s : controlOps) {
count->insertOp("TF", s);
}
count->insertOp("ONNX", "Identity");
});
return gConverter;
}
PostConverter* PostConverter::get(std::string key) {
auto gConverter = getConvertMap();
if (gConverter->find(key) != gConverter->end()) {
return gConverter->at(key).get();
}
return nullptr;
}
void PostConverter::add(std::shared_ptr<PostConverter> converter, std::string key) {
auto gConverter = getConvertMap();
gConverter->insert(std::make_pair(key, converter));
}
bool PostTreatUtils::_isSingleInputOutput(const MNN::OpT* op) {
if (op->inputIndexes.size() != 1 || op->outputIndexes.size() != 1) {
return false;
}
return true;
}
MNN::OpT* PostTreatUtils::_findOpByOutputIndex(int outputIndex, const NetT* net) {
for (auto& op : net->oplists) {
if (inVector(op->outputIndexes, outputIndex)) {
return op.get();
}
}
return nullptr;
}
std::vector<MNN::OpT*> PostTreatUtils::_findOpByInputIndex(int inputIndex, const NetT* net) {
std::vector<MNN::OpT*> ops;
for (auto& op : net->oplists) {
if (inVector(op->inputIndexes, inputIndex)) {
ops.push_back(op.get());
}
}
// check whether the next op is in_place op
const int opsSize = ops.size();
if (opsSize > 1) {
auto realNextOp = ops[0];
if (inVector(realNextOp->outputIndexes, inputIndex)) {
ops.clear();
ops.push_back(realNextOp);
}
}
return ops;
}
int PostTreatUtils::_getOpDecestorCount(MNN::OpT* op, const MNN::NetT* mNet) {
int decestorCount = 0;
for (auto& otherOp : mNet->oplists) {
if (otherOp.get() != op) {
for (auto inputIndex : otherOp->inputIndexes) {
if (inVector(op->outputIndexes, inputIndex)) {
decestorCount++;
break; // one decestor just count one.
}
}
}
}
return decestorCount;
}
void PostTreatUtils::_removeOpInNet(MNN::OpT* op, MNN::NetT* net) {
for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) {
if (iter->get() == op) {
// LOG(INFO) << "remove op: " << op->name;
net->oplists.erase(iter);
break;
}
}
}
bool PostTreatUtils::_replace(std::vector<int>& indexes, int freshIndex, int oldIndex) {
auto iter = indexes.begin();
while (iter != indexes.end()) {
if (*iter == oldIndex) {
*iter = freshIndex;
return true;
}
++iter;
}
return false;
}