tools/converter/source/optimizer/merge/MergeHelpers.cpp (197 lines of code) (raw):

// // MergeHelpers.cpp // MNNConverter // // Created by MNN on b'2020/07/20'. // Copyright © 2018, Alibaba Group Holding Limited // #include <unordered_map> #include <vector> #include "MNN_generated.h" #include "MergeHelpers.hpp" using namespace MNN::Express; namespace MNN { namespace helpers { static MNN_DATA_FORMAT convertFormat(Express::Dimensionformat format) { switch (format) { case Express::NCHW: return MNN_DATA_FORMAT_NCHW; case Express::NHWC: return MNN_DATA_FORMAT_NHWC; case Express::NC4HW4: return MNN_DATA_FORMAT_NC4HW4; default: return MNN_DATA_FORMAT_UNKNOWN; } } bool IsConstant(EXPRP expr) { const Op* op = expr->get(); if ((op && op->type() == OpType_Const) || (!op && expr->inputType() == VARP::CONSTANT)) { return true; } return false; } bool IsBinaryOp(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_BinaryOp; } bool IsCast(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Cast; } bool IsConcat(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Concat; } bool IsReshape(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Reshape; } bool IsUnsqueeze(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Unsqueeze; } bool IsTranspose(EXPRP expr) { const Op* op = expr->get(); return op && (op->type() == OpType_Transpose || op->type() == OpType_Permute); } bool IsScatterNd(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_ScatterNd; } bool IsMatMul(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_MatMul; } bool IsSoftmax(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Softmax; } bool IsSelect(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Select; } bool IsGatherV2(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_GatherV2; } bool IsSlice(EXPRP expr) { const Op* op = expr->get(); return op && (op->type() == OpType_Slice || op->type() == OpType_StridedSlice || op->type() == OpType_SliceTf); } bool IsUnaryOp(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_UnaryOp; } bool IsLayerNorm(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_LayerNorm; } #define IS_BINARY_OP_TYPE(op_type) \ if (!IsBinaryOp(expr)) { \ return false; \ } \ int type = expr->get()->main_as_BinaryOp()->opType(); \ return type == op_type; #define IS_UNARY_OP_TYPE(op_type) \ if (!IsUnaryOp(expr)) { \ return false; \ } \ int type = expr->get()->main_as_UnaryOp()->opType(); \ return type == op_type; bool IsBinaryAdd(EXPRP expr) { IS_BINARY_OP_TYPE(BinaryOpOperation_ADD); } bool IsBinarySub(EXPRP expr) { IS_BINARY_OP_TYPE(BinaryOpOperation_SUB); } bool IsBinaryMul(EXPRP expr) { IS_BINARY_OP_TYPE(BinaryOpOperation_MUL); } bool IsBinaryRealDiv(EXPRP expr) { IS_BINARY_OP_TYPE(BinaryOpOperation_REALDIV); } bool IsBinarySquaredDifference(Express::EXPRP expr) { IS_BINARY_OP_TYPE(BinaryOpOperation_SquaredDifference); } bool IsUnarySquare(EXPRP expr) { IS_UNARY_OP_TYPE(UnaryOpOperation_SQUARE); } bool IsBinaryPow(EXPRP expr) { IS_BINARY_OP_TYPE(BinaryOpOperation_POW); } bool IsUnarySqrt(EXPRP expr) { IS_UNARY_OP_TYPE(UnaryOpOperation_SQRT); } bool IsUnaryRsqrt(EXPRP expr) { IS_UNARY_OP_TYPE(UnaryOpOperation_RSQRT); } bool IsUnaryNeg(EXPRP expr) { IS_UNARY_OP_TYPE(UnaryOpOperation_NEG); } #undef IS_BINARY_OP_TYPE #undef IS_UNARY_OP_TYPE bool IsReductionMean(EXPRP expr) { const Op* op = expr->get(); if (!op || op->type() != OpType_Reduction) { return false; } int type = op->main_as_ReductionParam()->operation(); return type == ReductionType_MEAN; } bool IsConvolution(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_Convolution; } bool IsExpandDims(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_ExpandDims; } bool IsBroadcastTo(EXPRP expr) { const Op* op = expr->get(); return op && op->type() == OpType_BroadcastTo; } EXPRP InputExpr(EXPRP expr, int input_index) { return expr->inputs().at(input_index)->expr().first; } EXPRP OutputExpr(EXPRP expr, int output_index) { return expr->outputs().at(output_index).lock(); } std::vector<VARP> OutputVars(EXPRP expr) { std::unordered_map<int, VARP> outputs; for (WeakEXPRP w : expr->outputs()) { EXPRP child = w.lock(); if (!child.get()) { continue; } for (VARP output : child->inputs()) { if (output.get() == nullptr) { continue; } int output_index = 0; EXPRP parent; std::tie(parent, output_index) = output->expr(); if (parent.get() == expr.get()) { outputs.emplace(output_index, output); } } } std::vector<VARP> v_outputs; for (const auto& it : outputs) { int index = 0; VARP output; std::tie(index, output) = it; if (!output.get()) { continue; } if (v_outputs.size() <= index) { v_outputs.resize(index + 1); } v_outputs[index] = output; } return std::move(v_outputs); } VARP ConvertLayout(VARP input, Dimensionformat dest_layout, Dimensionformat src_layout) { std::unique_ptr<OpT> convert(new OpT); convert->type = OpType_ConvertTensor; convert->main.type = OpParameter_TensorConvertInfo; convert->main.value = new TensorConvertInfoT; convert->main.AsTensorConvertInfo()->dest = convertFormat(dest_layout); convert->main.AsTensorConvertInfo()->source = convertFormat(src_layout); return (Variable::create(Expr::create(convert.get(), {input}))); } } // namespace helpers } // namespace MNN