// virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override
//
// From tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp [lines 319:602]


    // Assign a concrete data layout (NHWC / NCHW / NC4HW4) to every tensor in the
    // net and insert ConvertTensor ops wherever a consumer requires a different
    // layout than its producer emits. When the source layout is NHWC, also rewrites
    // layout-sensitive op attributes (axes, permutations, reshape/input dims) for
    // ops whose output becomes NC4HW4. Mutates `net` in place; always returns true.
    virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
        auto& mNet = net;
        // Caffe-sourced nets need no treatment from this pass.
        if (mNet->sourceType == MNN::NetSource_CAFFE) {
            return true;
        }
        auto* ctx = Global<MNN::Express::OptimizeContext>::Get();

        // ONNX / Torch graphs are natively NCHW; every other source handled here is NHWC.
        auto originTensorType = MNN::MNN_DATA_FORMAT_NHWC;
        if (mNet->sourceType == MNN::NetSource_ONNX || mNet->sourceType == MNN::NetSource_TORCH) {
            originTensorType = MNN::MNN_DATA_FORMAT_NCHW;
        }
        // Normalize per-op layout hints to the source layout: const blobs that are
        // not already NC4HW4, and Reshape's dimType, are stamped with the origin format.
        for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) {
            auto op = iter->get();
            if (OpParameter_Blob == op->main.type) {
                if (op->main.AsBlob()->dataFormat != MNN_DATA_FORMAT_NC4HW4) {
                    op->main.AsBlob()->dataFormat = originTensorType;
                }
            }
            if (OpParameter_Reshape == op->main.type) {
                op->main.AsReshape()->dimType = originTensorType;
            }
        }

        auto config = Global<modelConfig>::Get();
        // NOTE(review): `version` appears unused in this function — confirm before removing.
        auto version = config->targetVersion;
        // Compute All Tensor's format
        // Fixed-point iteration: each pass tries to decide the output format of every
        // not-yet-ready op from the formats already known (tensorFormats), and repeats
        // until a full pass makes no further progress.
        std::vector<MNN_DATA_FORMAT> tensorFormats(net->tensorName.size());
        std::vector<bool> readyMask(net->oplists.size());
        std::fill(tensorFormats.begin(), tensorFormats.end(), MNN_DATA_FORMAT_UNKNOWN);
        std::fill(readyMask.begin(), readyMask.end(), false);
        bool hasChange = false;
        bool complete = false;
        // Record Const Op Index
        std::vector<int32_t> constTensorIndexs;
        do {
            complete = true;
            hasChange = false;
            for (int i=0; i<readyMask.size(); ++i) {
                if (readyMask[i]) {
                    continue;
                }
                auto op = net->oplists[i].get();
                readyMask[i] = _computeTensorFormat(tensorFormats, constTensorIndexs, op, originTensorType, config->keepInputFormat, false);
                if (readyMask[i]) {
                    hasChange = true;
                } else {
                    complete = false;
                }
            }
        } while (hasChange);

        // Has can't determine one, force compability op use originFormat
        if (!complete) {
            for (int i=0; i<readyMask.size(); ++i) {
                if (readyMask[i]) {
                    continue;
                }
                auto op = net->oplists[i].get();
                // Final argument `true` forces a decision for ops the fix-point
                // iteration could not resolve.
                readyMask[i] = _computeTensorFormat(tensorFormats, constTensorIndexs, op, originTensorType, config->keepInputFormat, true);
                MNN_ASSERT(readyMask[i] == true);
            }
        }
        // Insert Extra Converter
        // convertMap: original tensor index -> output index of the ConvertTensor op
        // already inserted for it, so one conversion is reused by later consumers.
        // NOTE(review): the map is keyed by source tensor only, not by target format;
        // verify _getRequireFormat cannot demand two different dest formats for the
        // same tensor, otherwise a cached conversion could be reused wrongly.
        std::map<int, int> convertMap;
        if (config->keepInputFormat) {
            // Change Output
            // Inputs keep their original layout, so any named graph output that ended
            // up NC4HW4 gets a trailing ConvertTensor back to the origin layout. The
            // converted tensor takes over the public output name; the producer's
            // tensor is renamed with a "__before_tr" suffix.
            auto& outputs = mNet->outputName;
            std::vector<std::unique_ptr<MNN::OpT>> extraOp;
            for (auto& op : mNet->oplists) {
                for (int idx : op->outputIndexes) {
                    for (int j = 0; j < outputs.size(); j++) {
                        if (mNet->tensorName[idx] == outputs[j]) {
                            auto outputFormat = tensorFormats[idx];
                            if (outputFormat == MNN_DATA_FORMAT_NC4HW4) {
                                auto newOutputName = outputs[j] + "__before_tr";
                                mNet->tensorName[idx] = newOutputName;
                                // Append a convert op
                                MNN::OpT* transformOp = new MNN::OpT;
                                MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT;
                                tc->source                  = outputFormat;
                                tc->dest                    = originTensorType;
                                transformOp->main.type      = MNN::OpParameter_TensorConvertInfo;
                                transformOp->main.value     = tc;
                                transformOp->name           = newOutputName;
                                transformOp->inputIndexes.push_back(idx);
                                // New tensor index and its format are appended in
                                // lock-step with tensorName.
                                int newOutputIndex = (int)mNet->tensorName.size();
                                transformOp->outputIndexes.push_back(newOutputIndex);
                                tensorFormats.push_back(originTensorType);
                                mNet->tensorName.push_back(outputs[j]);
                                transformOp->type   = MNN::OpType_ConvertTensor;
                                extraOp.emplace_back(transformOp);
                            }
                        }
                    }
                }
            }
            // Converters are appended after the loop to avoid invalidating the
            // oplists iteration above.
            for (auto&& op : extraOp) {
                mNet->oplists.emplace_back(std::move(op));
            }
        } else {
            // Change Input
            // Inputs adopt the computed format. On the first optimization run, a 4-D
            // NHWC input that became NC4HW4 has its dims physically reordered to NCHW.
            for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) {
                auto& op         = *iter;
                if (OpType_Input == op->type) {
                    auto originInputFormat = op->main.AsInput()->dformat;
                    op->main.AsInput()->dformat = tensorFormats[op->outputIndexes[0]];
                    if (originInputFormat == MNN_DATA_FORMAT_NHWC && op->main.AsInput()->dformat == MNN_DATA_FORMAT_NC4HW4 && op->main.AsInput()->dims.size() == 4 && ctx->first_run) {
                        int n = op->main.AsInput()->dims[0];
                        int h = op->main.AsInput()->dims[1];
                        int w = op->main.AsInput()->dims[2];
                        int c = op->main.AsInput()->dims[3];
                        op->main.AsInput()->dims = {n, c, h, w};
                    }
                }
            }
        }
        if (originTensorType == MNN_DATA_FORMAT_NHWC) {
            for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) {
                auto op = iter->get();
                // Insert Pretreat Op if needed
                // A Padding op's pad-amount tensor lists pads per axis in NHWC order;
                // once the op runs NC4HW4 (NCHW axis order) the pads must be permuted
                // via a GatherV2 with the const index vector {0, 3, 1, 2}.
                if (op->type == OpType_Padding && tensorFormats[op->outputIndexes[0]] == MNN_DATA_FORMAT_NC4HW4 && ctx->first_run) {
                    const int padValueIndex = op->inputIndexes[1];
                    // NOTE(review): `padValueOp` appears unused — confirm before removing.
                    auto padValueOp         = PostTreatUtils::_findOpByOutputIndex(padValueIndex, mNet.get());
                    // Add Gather op for padding, turn nhwc -> nchw
                    std::unique_ptr<OpT> gatherIndex(new OpT);
                    gatherIndex->outputIndexes = {(int)mNet->tensorName.size()};
                    gatherIndex->type          = OpType_Const;
                    gatherIndex->name          = op->name + "_Gather_Index";
                    mNet->tensorName.emplace_back(gatherIndex->name);
                    tensorFormats.push_back(originTensorType);
                    gatherIndex->main.type                 = OpParameter_Blob;
                    gatherIndex->main.value                = new BlobT;
                    gatherIndex->main.AsBlob()->dataType   = DataType_DT_INT32;
                    gatherIndex->main.AsBlob()->dataFormat = originTensorType;
                    gatherIndex->main.AsBlob()->int32s     = {0, 3, 1, 2};
                    gatherIndex->main.AsBlob()->dims       = {4};

                    std::unique_ptr<OpT> gather(new OpT);
                    gather->outputIndexes = {(int)mNet->tensorName.size()};
                    gather->inputIndexes  = {op->inputIndexes[1], gatherIndex->outputIndexes[0]};

                    gather->type = OpType_GatherV2;
                    gather->name = op->name + "_Gather";
                    mNet->tensorName.emplace_back(gather->name);
                    tensorFormats.push_back(originTensorType);

                    // Rewire the Padding op to consume the permuted pads, then insert
                    // the new ops ahead of it: [gatherIndex, gather, Padding].
                    op->inputIndexes[1]                       = gather->outputIndexes[0];
                    iter = mNet->oplists.insert(iter, std::move(gather));
                    iter = mNet->oplists.insert(iter, std::move(gatherIndex));
                    // Step over gatherIndex, gather, and the Padding op itself.
                    iter++;
                    iter++;
                    iter++;
                } else {
                    iter++;
                }
            }
        }

        // For every op whose content is layout-sensitive, compare each input's
        // current format against what the op requires and insert (or reuse via
        // convertMap) a ConvertTensor op ahead of it.
        for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) {
            auto& op         = *iter;
            if (op->inputIndexes.empty()) {
                iter++;
                continue;
            }
            if (!_OpNeedConvertContent(op->type)) {
                iter++;
                continue;
            }
            auto formatType  = _getFormatType(op.get(), originTensorType);
            std::vector<MNN::OpT*> transformOps;
            auto currentName         = op->name;
            for (int i = 0; i < op->inputIndexes.size(); ++i) {
                auto inputIndex = op->inputIndexes[i];
                if (inputIndex < 0) {
                    continue; // optional input, ignore it
                }
                auto type = tensorFormats[inputIndex];
                auto requireType = _getRequireFormat(formatType, i, tensorFormats[op->outputIndexes[0]], originTensorType);
                if (type == requireType) {
                    continue;
                }

                // A converter for this tensor already exists: just re-point the input.
                if (convertMap.find(op->inputIndexes[i]) != convertMap.end()) {
                    op->inputIndexes[i] = convertMap[op->inputIndexes[i]];
                    continue;
                }

                // Insert Transform op
                MNN::OpT* transformOp = new MNN::OpT;
                transformOps.push_back(transformOp);
                MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT;
                tc->source                  = type;
                tc->dest                    = requireType;
                transformOp->main.type      = MNN::OpParameter_TensorConvertInfo;
                transformOp->main.value     = tc;
                transformOp->name           = mNet->tensorName[inputIndex] + "___tr4" + op->name;
                // printf("Insert convert for %s, %s 's input %d\n", net->tensorName[inputIndex].c_str(),
                // op->name.c_str(), i);
                transformOp->inputIndexes.push_back(inputIndex);
                transformOp->outputIndexes.push_back(mNet->tensorName.size());
                convertMap[inputIndex] = transformOp->outputIndexes[0];
                tensorFormats.push_back(requireType);
                mNet->tensorName.push_back(transformOp->name);
                op->inputIndexes[i] = transformOp->outputIndexes[0];
                transformOp->type   = MNN::OpType_ConvertTensor;
            }
            // Insert the new converters immediately before the current op (reverse
            // order keeps their relative order), then re-sync `iter` by scanning
            // forward to the op we were processing and stepping past it.
            for (int i = transformOps.size() - 1; i >= 0; i--) {
                iter = mNet->oplists.insert(iter, std::unique_ptr<MNN::OpT>(transformOps[i]));
            }
            for (; (*iter)->name != currentName; iter++) {
            }
            iter++;
        }

        // NCHW sources need no axis rewriting: NC4HW4 shares NCHW axis order.
        if (originTensorType == MNN_DATA_FORMAT_NCHW) {
            return true;
        }

        // For NHWC -> NC4HW4 op, should Reset axis map
        // axisMap[i] is where NHWC axis i lands in NCHW order (N,H,W,C -> N,C,H,W).
        const int axisMap[4] = {0, 2, 3, 1};

        for (auto& op : mNet->oplists) {
            if (op->inputIndexes.empty()) {
                continue;
            }
            if (tensorFormats[op->outputIndexes[0]] != MNN_DATA_FORMAT_NC4HW4) {
                continue;
            }
            // Only rewrite attributes once; later runs see already-converted axes.
            if (!ctx->first_run) {
                continue;
            }
            if (MNN::OpType_Input == op->type) {
                // Rotate the trailing channel dim into position 1 (…HWC -> C…HW).
                auto input        = op->main.AsInput();
                const int dimSize = input->dims.size();
                if (dimSize > 2) {
                    const int channel = input->dims[dimSize - 1];
                    for (int i = dimSize - 1; i > 1; --i) {
                        input->dims[i] = input->dims[i - 1];
                    }
                    input->dims[1] = channel;
                }
            }
            if (MNN::OpType_Concat == op->type) {
                auto axis       = op->main.AsAxis();
                auto concatAxis = axis->axis;
                if (concatAxis < 0) {
                    concatAxis = 4 + concatAxis;
                }
                DCHECK(concatAxis >= 0 && concatAxis <= 3) << "Concat axis ERROR!";
                axis->axis = axisMap[concatAxis];
            }
            if (MNN::OpType_Permute == op->type) {
                auto permuteT = op->main.AsPermute();
                for (int i = 0; i < permuteT->dims.size(); ++i) {
                    DCHECK(permuteT->dims[i] >= 0 && permuteT->dims[i] <= 3) << "Dim Error ==> " << op->name;
                    permuteT->dims[i] = axisMap[permuteT->dims[i]];
                }
            }
            if (MNN::OpType_Slice == op->type) {
                auto slice     = op->main.AsSlice();
                auto sliceAxis = slice->axis;
                if (sliceAxis < 0) {
                    sliceAxis = 4 + sliceAxis;
                }
                DCHECK(sliceAxis >= 0 && sliceAxis <= 3) << "Slice axis ERROR!";
                slice->axis = axisMap[sliceAxis];
            }
            if (MNN::OpType_Reshape == op->type) {
                auto reshape   = op->main.AsReshape();
                auto originDim = reshape->dims;
                for (int i = 0; i < reshape->dims.size(); ++i) {
                    // NOTE(review): this CHECK tests the loop index, which is always
                    // true for i < 4; it was likely meant to guard dims.size() <= 4
                    // (axisMap would be over-read for larger reshapes) — confirm.
                    CHECK(i >= 0 && i <= 3) << "Error";
                    reshape->dims[axisMap[i]] = originDim[i];
                }
            }
            if (MNN::OpType_ArgMax == op->type || MNN::OpType_ArgMin == op->type) {
                auto param      = op->main.AsArgMax();
                auto originAxis = param->axis;
                DCHECK(originAxis >= 0 && originAxis <= 3) << "ArgMax / Argmin axis ERROR!";
                param->axis = axisMap[originAxis];
            }
        }
        return true;
    }