virtual bool onExecute()

in tools/converter/source/optimizer/postconvert/RemoveInvalidCast.cpp [40:245]


    // Post-convert pass that drops Cast / CastLike ops whose input and output data types are equal.
    //
    // Steps:
    //   1. Return early for TensorFlow, TFLite and Caffe sources, and for nets without any Cast op.
    //   2. Infer a data type for each tensor index; DT_INVALID means "unknown". Types are propagated
    //      in op-list order, so a consumer only sees its producer's type if the producer comes first.
    //   3. For each Cast / CastLike op whose input type is known:
    //        - types differ: keep the op; a CastLike with a known output type is rewritten to a Cast;
    //        - types match and the output tensor is listed in net->outputName: keep the op;
    //        - types match otherwise: point every consumer at the cast's input and erase the op.
    //      Cast / CastLike ops without any input or output index are erased outright.
    //
    // net: the network, modified in place.
    // Returns true on every path.
    virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
        if (net->sourceType == MNN::NetSource_TENSORFLOW || net->sourceType == MNN::NetSource_TFLITE) {
            // These two frameworks carry a valid source type on their casts, nothing to treat.
            return true;
        }
        if (net->sourceType == MNN::NetSource_CAFFE) {
            // Caffe has no invalid cast op.
            return true;
        }
        // Only the presence of a Cast op (not CastLike) triggers the pass.
        bool needTreat = false;
        for (auto& op : net->oplists) {
            if (op->type == MNN::OpType_Cast) {
                needTreat = true;
                break;
            }
        }
        if (!needTreat) {
            return true;
        }

        // Infer the data type of every tensor; unknown tensors stay DT_INVALID.
        std::vector<MNN::DataType> types(net->tensorName.size(), MNN::DataType_DT_INVALID);

        // Pass 1: ops treated as float ops mark their first input (and a single output) as float.
        for (auto& op : net->oplists) {
            switch (op->type) {
                // Float Op
                case MNN::OpType_PReLU:
                case MNN::OpType_Softmax:
                case MNN::OpType_Convolution:
                case MNN::OpType_ConvolutionDepthwise:
                case MNN::OpType_Convolution3D:
                case MNN::OpType_Deconvolution:
                case MNN::OpType_DeconvolutionDepthwise:
                case MNN::OpType_Interp:
                case MNN::OpType_LSTM:
                case MNN::OpType_LSTMBlockCell:
                case MNN::OpType_GridSample:
                case MNN::OpType_RNNSequenceGRU:
                case MNN::OpType_MatMul:
                    types[op->inputIndexes[0]] = MNN::DataType_DT_FLOAT;
                    if (op->outputIndexes.size() == 1) {
                        // Original note: "4 is integer matmul".
                        // NOTE(review): the check is on the output count being exactly 1 - confirm
                        // what the "4" refers to.
                        types[op->outputIndexes[0]] = MNN::DataType_DT_FLOAT;
                    }
                    break;
                default:
                    break;
            }
        }

        // Pass 2: propagate types through ops whose output type follows from their parameters or inputs.
        for (auto& op : net->oplists) {
            if (op->outputIndexes.empty()) {
                // Every case below that does any work writes outputIndexes[0]; nothing to infer here.
                continue;
            }
            switch (op->type) {
                case MNN::OpType_Input:
                    types[op->outputIndexes[0]] = op->main.AsInput()->dtype;
                    break;
                case MNN::OpType_Cast:
                    types[op->outputIndexes[0]] = op->main.AsCastParam()->dstT;
                    break;
                case MNN::OpType_CastLike:
                    // Output takes the type of the second input.
                    types[op->outputIndexes[0]] = types[op->inputIndexes[1]];
                    break;
                case MNN::OpType_Const:
                case MNN::OpType_TrainableParam:
                    types[op->outputIndexes[0]] = op->main.AsBlob()->dataType;
                    break;
                case MNN::OpType_Fill:
                    types[op->outputIndexes[0]] = types[op->inputIndexes[1]];
                    break;
                case MNN::OpType_Slice:
                case MNN::OpType_SliceTf:
                case MNN::OpType_Unpack:
                    // Every output shares the type of the first input.
                    for (auto v : op->outputIndexes) {
                        types[v] = types[op->inputIndexes[0]];
                    }
                    break;
                case MNN::OpType_GatherV2:
                case MNN::OpType_GatherND:
                case MNN::OpType_Reduction:
                case MNN::OpType_Range:
                    types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
                    break;
                case MNN::OpType_Shape:
                case MNN::OpType_Size:
                case MNN::OpType_Rank:
                case MNN::OpType_UnravelIndex:
                    types[op->outputIndexes[0]] = MNN::DataType_DT_INT32;
                    break;
                case MNN::OpType_Unique:
                    // First output follows the input, every further output is int32.
                    types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
                    for (size_t v = 1; v < op->outputIndexes.size(); ++v) {
                        types[op->outputIndexes[v]] = MNN::DataType_DT_INT32;
                    }
                    break;
                case MNN::OpType_RandomUniform:
                    types[op->outputIndexes[0]] = op->main.AsRandomUniform()->type;
                    break;
                case MNN::OpType_ArgMax:
                    types[op->outputIndexes[0]] = MNN::DataType_DT_INT32;
                    break;
                case MNN::OpType_TopKV2:
                    // First output follows the input, the optional second output is int32.
                    types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
                    if (op->outputIndexes.size() > 1) {
                        types[op->outputIndexes[1]] = MNN::DataType_DT_INT32;
                    }
                    break;
                case MNN::OpType_ScatterNd:
                case MNN::OpType_Select:
                    types[op->outputIndexes[0]] = types[op->inputIndexes[1]];
                    break;
                case MNN::OpType_OneHot:
                    types[op->outputIndexes[0]] = types[op->inputIndexes[2]];
                    break;
                case MNN::OpType_Extra:
                case MNN::OpType_Plugin:
                    // No type is inferred for these ops.
                    break;
                case MNN::OpType_BinaryOp:
                {
                    if (outputBool(op->main.AsBinaryOp()->opType)) {
                        types[op->outputIndexes[0]] = DataType_DT_BOOL;
                    } else {
                        types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
                    }
                }
                    break;
                // Deform: the output keeps the type of the first input.
                case MNN::OpType_Broastcast:
                case MNN::OpType_Concat:
                case MNN::OpType_Crop:
                case MNN::OpType_CropAndResize:
                case MNN::OpType_Col2Im:
                case MNN::OpType_DepthToSpace:
                case MNN::OpType_ExpandDims:
                case MNN::OpType_Flatten:
                case MNN::OpType_Interp:
                case MNN::OpType_Interp3D:
                case MNN::OpType_Im2Col:
                case MNN::OpType_Pack:
                case MNN::OpType_Padding:
                case MNN::OpType_Permute:
                case MNN::OpType_Reshape:
                case MNN::OpType_Resize:
                case MNN::OpType_StridedSlice:
                case MNN::OpType_SpaceToDepth:
                case MNN::OpType_Squeeze:
                case MNN::OpType_Transpose:
                case MNN::OpType_Unsqueeze:
                {
                    types[op->outputIndexes[0]] = types[op->inputIndexes[0]];
                }
                    break;
                default:
                    break;
            }
        }

        // Remove useless casts. The iterator is advanced by hand because ops are erased while iterating.
        for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
            auto& op = *iter;
            if (op->type != MNN::OpType_Cast && op->type != MNN::OpType_CastLike) {
                ++iter;
                continue;
            }
            // A cast without input or output cannot be rewired, drop it. This has to come before
            // any inputIndexes[0] / outputIndexes[0] access below.
            if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
                iter = net->oplists.erase(iter);
                continue;
            }
            const auto inputType  = types[op->inputIndexes[0]];
            const auto outputType = types[op->outputIndexes[0]];
            if (inputType == MNN::DataType_DT_INVALID) {
                // Unknown source type: cannot prove the cast is redundant.
                ++iter;
                continue;
            }
            if (inputType != outputType) {
                // The cast changes the type, so it stays.
                if (op->type == MNN::OpType_CastLike && outputType != MNN::DataType_DT_INVALID) {
                    // Turn CastLike into Cast: keep only the data input and store the target type.
                    op->type = MNN::OpType_Cast;
                    op->inputIndexes = {op->inputIndexes[0]};
                    op->main.Reset();
                    op->main.value = new CastParamT;
                    op->main.type = OpParameter_CastParam;
                    op->main.AsCastParam()->dstT = outputType;
                }
                ++iter;
                continue;
            }
            // Keep a redundant cast when it produces one of the net's named outputs.
            if (std::find(net->outputName.begin(), net->outputName.end(), net->tensorName[op->outputIndexes[0]]) != net->outputName.end()) {
                ++iter;
                continue;
            }

            // Redirect every consumer of the cast's outputs to the cast's input, then erase the cast.
            // Copies are taken because the op is erased right after the rewiring.
            const auto originInput   = op->inputIndexes[0];
            const auto originOutputs = op->outputIndexes;
            for (auto& subOp : net->oplists) {
                for (auto& index : subOp->inputIndexes) {
                    if (std::find(originOutputs.begin(), originOutputs.end(), index) != originOutputs.end()) {
                        index = originInput;
                    }
                }
            }
            iter = net->oplists.erase(iter);
        }
        return true;
    }