bool Cli::initializeMNNConvertArgs()

in tools/converter/source/common/cli.cpp [135:513]


/**
 * Parse the MNNConvert command line and fill @p modelPath accordingly.
 *
 * @param modelPath Output configuration. Optional settings are only written
 *                  when the corresponding option appears on the command line.
 * @param argc      Argument count as passed to main().
 * @param argv      Argument vector as passed to main().
 * @return true when the caller should go on to convert the model. This
 *         includes the `--info` case on an MNN model, where only
 *         `model`, `modelFile` and `dumpInfo` are set and all later
 *         options are skipped. Returns false when the caller should stop:
 *         either an informational option (--help, --version, --OP) was
 *         handled, or the arguments are invalid or incomplete.
 *
 * Exceptions thrown by cxxopts while parsing (e.g. a non-numeric value for an
 * int option) are not caught here.
 */
bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv) {
    cxxopts::Options options("MNNConvert");

    options.positional_help("[optional args]").show_positional_help();

    // Unknown options are tolerated instead of aborting the parse.
    options.allow_unrecognised_options().add_options()
    (std::make_pair("h", "help"), "Convert Other Model Format To MNN Model\n")
    (std::make_pair("v", "version"), "show current version")
    (std::make_pair("f", "framework"),
#ifdef MNN_BUILD_TORCH
     "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,TORCH,JSON]",
#else
     "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,JSON]",
#endif
     cxxopts::value<std::string>())
    ("modelFile",
     "tensorflow Pb or caffeModel, ex: *.pb,*caffemodel",
     cxxopts::value<std::string>())
    ("batch",
     "if model input's batch is not set, set as the batch size you set",
     cxxopts::value<int>())
    ("keepInputFormat",
     "keep input dimension format or not, default: true",
     cxxopts::value<bool>())
    ("optimizeLevel",
     "graph optimize option, 0: don't run optimize(only support for MNN source), 1: use graph optimize only for every input case is right, 2: normally right but some case may be wrong, default 1",
     cxxopts::value<int>())
    ("optimizePrefer",
     "graph optimize option, 0 for normal, 1 for smallest, 2 for fastest",
     cxxopts::value<int>())
    ("prototxt",
     "only used for caffe, ex: *.prototxt",
     cxxopts::value<std::string>())
    ("MNNModel",
     "MNN model, ex: *.mnn",
     cxxopts::value<std::string>())
    ("fp16",
     "save Conv's weight/bias in half_float data type")
    ("benchmarkModel",
     "Do NOT save big size data, such as Conv's weight,BN's gamma,beta,mean and variance etc. Only used to test the cost of the model")
    ("bizCode",
     "MNN Model Flag, ex: MNN",
     cxxopts::value<std::string>())
    ("debug",
     "Enable debugging mode.")
    ("forTraining",
     "whether or not to save training ops BN and Dropout, default: false",
     cxxopts::value<bool>())
    ("weightQuantBits",
     "save conv/matmul/LSTM float weights to int8 type, only optimize for model size, 2-8 bits, default: 0, which means no weight quant",
     cxxopts::value<int>())
    ("weightQuantAsymmetric",
     "the default weight-quant uses SYMMETRIC quant method, which is compatible with old MNN versions. "
     "you can try set --weightQuantAsymmetric to use asymmetric quant method to improve accuracy of the weight-quant model in some cases, "
     "but asymmetric quant model cannot run on old MNN versions. You will need to upgrade MNN to new version to solve this problem. default: false",
     cxxopts::value<bool>())
    ("weightQuantBlock",
     "using block-wise weight quant, set block size, default: -1, which means channel-wise weight quant",
     cxxopts::value<int>())
    ("compressionParamsFile",
     "The path of the compression parameters that stores activation, "
     "weight scales and zero points for quantization or information "
     "for sparsity. "
     "if the file does not exist, will create file base on user's option",
     cxxopts::value<std::string>())
    ("OP",
     "print framework supported op",
     cxxopts::value<bool>())
    ("saveStaticModel",
     "save static model with fix shape, default: false",
     cxxopts::value<bool>())
    ("targetVersion",
     "compatibility for old mnn engine, default the same as converter",
     cxxopts::value<float>())
    ("customOpLibs",
     "custom op libs ex: libmy_add.so;libmy_sub.so",
     cxxopts::value<std::string>())
    ("info",
     "dump MNN's model info")
    ("authCode",
     "code for model authentication.",
     cxxopts::value<std::string>())
    ("inputConfigFile",
     "set input config file for static model, ex: ~/config.txt",
     cxxopts::value<std::string>())
    ("testdir",
     "set test dir, mnn will convert model and then check the result",
     cxxopts::value<std::string>())
    ("testconfig",
     "set test config json, example: tools/converter/forward.json",
     cxxopts::value<std::string>())
    ("thredhold",
     "if set test dir, thredhold mean the max rate permit for run MNN model and origin error",
     cxxopts::value<float>())
    ("JsonFile",
     "if input model is MNN and give jsonfile, will dump MNN model to the JsonFile.",
     cxxopts::value<std::string>())
    ("alignDenormalizedValue",
     "if 1, converter would align denormalized float(|x| < 1.18e-38) as zero, because of in ubuntu/protobuf or android/flatbuf, system behaviors are different. default: 1, range: {0, 1}",
     cxxopts::value<int>())
    ("detectSparseSpeedUp",
     "if add the flag converter would detect weights sparsity and check sparse speedup, may decrease model size, but will cause more time for convert.")
    ("saveExternalData",
     "save weight to external bin file.",
     cxxopts::value<bool>())
    ("useGeluApproximation",
     "Use Gelu Approximation Compute Instead of use ERF",
     cxxopts::value<int>())
    ("convertMatmulToConv",
     "if 1, converter matmul with constant input to convolution. default: 1, range: {0, 1}",
     cxxopts::value<int>())
    ("transformerFuse",
     "fuse key transformer op, like attention. default: false",
     cxxopts::value<bool>())
    ("allowCustomOp",
     "allow custom op when convert. default: false",
     cxxopts::value<bool>());

    auto result = options.parse(argc, argv);

    if (result.count("help")) {
        std::cout << options.help({""}) << std::endl;
        return false;
    }

    if (result.count("version")) {
        std::cout << MNN_VERSION << std::endl;
        return false;
    }

    // Model source framework. MAX_SOURCE is the "unset" sentinel.
    modelPath.model = modelPath.MAX_SOURCE;
    std::string frameWork;
    if (result.count("framework")) {
        frameWork = result["framework"].as<std::string>();
        if ("TF" == frameWork) {
            modelPath.model = modelConfig::TENSORFLOW;
        } else if ("CAFFE" == frameWork) {
            modelPath.model = modelConfig::CAFFE;
        } else if ("ONNX" == frameWork) {
            modelPath.model = modelConfig::ONNX;
        } else if ("MNN" == frameWork) {
            modelPath.model = modelConfig::MNN;
        } else if ("TFLITE" == frameWork) {
            modelPath.model = modelConfig::TFLITE;
#ifdef MNN_BUILD_TORCH
        } else if ("TORCH" == frameWork) {
            modelPath.model = modelConfig::TORCH;
#endif
        } else if ("JSON" == frameWork) {
            modelPath.model = modelConfig::JSON;
        } else {
            std::cout << "Framework Input ERROR or Not Support This Model Type Now!" << std::endl;
            return false;
        }
    } else {
        std::cout << options.help({""}) << std::endl;
        DLOG(INFO) << "framework Invalid, use -f TF/CAFFE/MNN/ONNX/TFLITE/TORCH/JSON !";
        return false;
    }

    // --OP only lists the ops registered for the chosen framework; never converts.
    if (result.count("OP")) {
        MNN_PRINT("Dump %s support Ops\n", frameWork.c_str());
        const auto& res = OpCount::get()->getMap().find(frameWork);
        if (res == OpCount::get()->getMap().end()) {
            return false;
        }
        for (const auto& iter : res->second) {
            MNN_PRINT("%s\n", iter.c_str());
        }
        MNN_PRINT("Total: %d\n", (int)res->second.size());
        return false;
    }

    // Model file path: mandatory and must exist.
    if (result.count("modelFile")) {
        const std::string modelFile = result["modelFile"].as<std::string>();
        if (CommonKit::FileIsExist(modelFile)) {
            modelPath.modelFile = modelFile;
        } else {
            DLOG(INFO) << "Model File Does Not Exist! ==> " << modelFile;
            return false;
        }
    } else {
        DLOG(INFO) << "modelFile Not set Invalid, use --modelFile to set!";
        return false;
    }

    // Optimize level
    if (result.count("optimizeLevel")) {
        modelPath.optimizeLevel = result["optimizeLevel"].as<int>();
        if (modelPath.optimizeLevel > 1) {
            DLOG(INFO) << "\n optimizeLevel > 1, some case may be wrong";
        }
    }

    // Prototxt path: optional in general, mandatory for Caffe.
    if (result.count("prototxt")) {
        const std::string prototxt = result["prototxt"].as<std::string>();
        if (CommonKit::FileIsExist(prototxt)) {
            modelPath.prototxtFile = prototxt;
        } else {
            DLOG(INFO) << "Proto File Does Not Exist! ==> " << prototxt;
            return false;
        }
    } else if (modelPath.model == modelPath.CAFFE) {
        DLOG(INFO) << "Proto File Not Set, use --prototxt XXX.prototxt to set it!";
        return false;
    }

    // Output path. --MNNModel takes precedence over --JsonFile. With neither
    // set, --info on an MNN model returns right away, so none of the options
    // below are applied in that case.
    if (result.count("MNNModel")) {
        modelPath.MNNModel = result["MNNModel"].as<std::string>();
    } else if (result.count("JsonFile")) {
        modelPath.mnn2json = true;
        modelPath.MNNModel = result["JsonFile"].as<std::string>();
    } else if (result.count("info") && modelPath.model == modelConfig::MNN) {
        modelPath.dumpInfo = true;
        return true;
    } else {
        DLOG(INFO) << "MNNModel File Not Set, use --MNNModel XXX.mnn to set it!";
        return false;
    }

    if (result.count("targetVersion")) {
        auto version = result["targetVersion"].as<float>();
        std::cout << "TargetVersion is " << version << std::endl;
        modelPath.targetVersion = version;
    }

    // MNN bizCode: falls back to "MNNTest" when not provided.
    if (result.count("bizCode")) {
        modelPath.bizCode = result["bizCode"].as<std::string>();
    } else {
        MNN_ERROR("Don't has bizCode, use MNNTest for default\n");
        modelPath.bizCode = "MNNTest";
    }

    if (result.count("inputConfigFile")) {
        modelPath.inputConfigFile = result["inputConfigFile"].as<std::string>();
    }

    // Pure flags (declared without a value type).
    if (result.count("fp16")) {
        modelPath.saveHalfFloat = true;
    }
    if (result.count("detectSparseSpeedUp")) {
        modelPath.detectSparseSpeedUp = true;
    }

    // Boolean-valued options. cxxopts gives value<bool>() an implicit value of
    // "true", so a bare `--flag` still yields true. Reading the value also
    // honours an explicit `--flag=false`, which the old code turned into true.
    if (result.count("forTraining")) {
        modelPath.forTraining = result["forTraining"].as<bool>();
    }
    if (result.count("keepInputFormat")) {
        modelPath.keepInputFormat = result["keepInputFormat"].as<bool>();
    }
    if (result.count("weightQuantAsymmetric")) {
        modelPath.weightQuantAsymmetric = result["weightQuantAsymmetric"].as<bool>();
    }
    if (result.count("saveStaticModel")) {
        modelPath.saveStaticModel = result["saveStaticModel"].as<bool>();
    }
    if (result.count("saveExternalData")) {
        modelPath.saveExternalData = result["saveExternalData"].as<bool>();
    }
    if (result.count("transformerFuse")) {
        modelPath.transformerFuse = result["transformerFuse"].as<bool>();
    }
    if (result.count("allowCustomOp")) {
        modelPath.allowCustomOp = result["allowCustomOp"].as<bool>();
    }

    // Integer-valued options.
    if (result.count("batch")) {
        modelPath.defaultBatchSize = result["batch"].as<int>();
    }
    if (result.count("weightQuantBits")) {
        modelPath.weightQuantBits = result["weightQuantBits"].as<int>();
    }
    if (result.count("weightQuantBlock")) {
        modelPath.weightQuantBlock = result["weightQuantBlock"].as<int>();
    }
    if (result.count("optimizePrefer")) {
        modelPath.optimizePrefer = result["optimizePrefer"].as<int>();
    }
    if (result.count("alignDenormalizedValue")) {
        modelPath.alignDenormalizedValue = result["alignDenormalizedValue"].as<int>();
    }
    if (result.count("convertMatmulToConv")) {
        modelPath.convertMatmulToConv = result["convertMatmulToConv"].as<int>();
    }
    if (result.count("useGeluApproximation")) {
        modelPath.useGeluApproximation = result["useGeluApproximation"].as<int>();
    }

    // String-valued options.
    // Int8 calibration / compression parameters path.
    if (result.count("compressionParamsFile")) {
        modelPath.compressionParamsFile = result["compressionParamsFile"].as<std::string>();
    }
    if (result.count("customOpLibs")) {
        modelPath.customOpLibs = result["customOpLibs"].as<std::string>();
    }
    if (result.count("authCode")) {
        modelPath.authCode = result["authCode"].as<std::string>();
    }
    if (result.count("testdir")) {
        modelPath.testDir = result["testdir"].as<std::string>();
    }
    if (result.count("testconfig")) {
        modelPath.testConfig = result["testconfig"].as<std::string>();
    }

    // Float-valued options.
    if (result.count("thredhold")) {
        modelPath.testThredhold = result["thredhold"].as<float>();
    }
    return true;
}