in tools/converter/source/common/cli.cpp [135:513]
bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv) {
cxxopts::Options options("MNNConvert");
options.positional_help("[optional args]").show_positional_help();
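// register command-line options; unrecognised options are ignored rather than treated as errors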
options.allow_unrecognised_options().add_options()(std::make_pair("h", "help"), "convert other model formats to MNN model\n")(
std::make_pair("v", "version"), "show current version")
(std::make_pair("f", "framework"),
#ifdef MNN_BUILD_TORCH
"model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,TORCH,JSON]",
#else
"model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN,JSON]",
#endif
cxxopts::value<std::string>())
(
"modelFile",
"tensorflow Pb or caffeModel, ex: *.pb,*caffemodel",
cxxopts::value<std::string>()
)
(
"batch",
"if model input's batch is not set, set as the batch size you set",
cxxopts::value<int>()
)
(
"keepInputFormat",
"keep input dimension format or not, default: true",
cxxopts::value<bool>()
)
(
"optimizeLevel",
"graph optimize option, 0: don't run optimize(only support for MNN source), 1: use graph optimize only for every input case is right, 2: normally right but some case may be wrong, default 1",
cxxopts::value<int>()
)
(
"optimizePrefer",
"graph optimize option, 0 for normal, 1 for smalleset, 2 for fastest",
cxxopts::value<int>()
)
(
"prototxt",
"only used for caffe, ex: *.prototxt",
cxxopts::value<std::string>())
(
"MNNModel",
"MNN model, ex: *.mnn",
cxxopts::value<std::string>())
(
"fp16",
"save Conv's weight/bias in half_float data type")
(
"benchmarkModel",
"Do NOT save big size data, such as Conv's weight,BN's gamma,beta,mean and variance etc. Only used to test the cost of the model")
(
"bizCode",
"MNN Model Flag, ex: MNN",
cxxopts::value<std::string>())
(
"debug",
"Enable debugging mode."
)
(
"forTraining",
"whether or not to save training ops BN and Dropout, default: false",
cxxopts::value<bool>()
)
(
"weightQuantBits",
"save conv/matmul/LSTM float weights to int8 type, only optimize for model size, 2-8 bits, default: 0, which means no weight quant",
cxxopts::value<int>()
)
(
"weightQuantAsymmetric",
"the default weight-quant uses SYMMETRIC quant method, which is compatible with old MNN versions. "
"you can try set --weightQuantAsymmetric to use asymmetric quant method to improve accuracy of the weight-quant model in some cases, "
"but asymmetric quant model cannot run on old MNN versions. You will need to upgrade MNN to new version to solve this problem. default: false",
cxxopts::value<bool>()
)
(
"weightQuantBlock",
"using block-wise weight quant, set block size, defaut: -1, which means channel-wise weight quant",
cxxopts::value<int>()
)
(
"compressionParamsFile",
"The path of the compression parameters that stores activation, "
"weight scales and zero points for quantization or information "
"for sparsity. "
"if the file does not exist, will create file base on user's option",
cxxopts::value<std::string>()
)
(
"OP",
"print framework supported op",
cxxopts::value<bool>()
)
(
"saveStaticModel",
"save static model with fix shape, default: false",
cxxopts::value<bool>()
)
(
"targetVersion",
"compability for old mnn engine, default the same as converter",
cxxopts::value<float>()
)
(
"customOpLibs",
"custom op libs ex: libmy_add.so;libmy_sub.so",
cxxopts::value<std::string>()
)
(
"info",
"dump MNN's model info"
)
(
"authCode",
"code for model authentication.",
cxxopts::value<std::string>()
)
(
"inputConfigFile",
"set input config file for static model, ex: ~/config.txt",
cxxopts::value<std::string>()
)
(
"testdir",
"set test dir, mnn will convert model and then check the result",
cxxopts::value<std::string>()
)
(
"testconfig",
"set test config json, example: tools/converter/forward.json",
cxxopts::value<std::string>()
)
(
"thredhold",
"if set test dir, thredhold mean the max rate permit for run MNN model and origin error",
cxxopts::value<float>()
)
(
"JsonFile",
"if input model is MNN and give jsonfile, while Dump MNN model to the JsonFile.",
cxxopts::value<std::string>()
)
(
"alignDenormalizedValue",
"if 1, converter would align denormalized float(|x| < 1.18e-38) as zero, because of in ubuntu/protobuf or android/flatbuf, system behaviors are different. default: 1, range: {0, 1}",
cxxopts::value<int>()
)
(
"detectSparseSpeedUp",
"if add the flag converter would detect weights sparsity and check sparse speedup, may decrease model size, but will cause more time for convert."
)
(
"saveExternalData",
"save weight to extenal bin file.",
cxxopts::value<bool>()
)
(
"useGeluApproximation",
"Use Gelu Approximation Compute Instead of use ERF",
cxxopts::value<int>()
)
(
"convertMatmulToConv",
"if 1, converter matmul with constant input to convolution. default: 1, range: {0, 1}",
cxxopts::value<int>()
)
(
"transformerFuse",
"fuse key transformer op, like attention. default: false",
cxxopts::value<bool>()
)
(
"allowCustomOp",
"allow custom op when convert. default: false",
cxxopts::value<bool>()
);
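// Typical invocations (illustrative file names, not taken from this repo):
//   ./MNNConvert -f ONNX --modelFile model.onnx --MNNModel model.mnn --bizCode biz
//   ./MNNConvert -f CAFFE --modelFile net.caffemodel --prototxt net.prototxt --MNNModel net.mnn --bizCode biz
// The parse result below fills in the modelConfig fields used by the converter.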
auto result = options.parse(argc, argv);
if (result.count("help")) {
std::cout << options.help({""}) << std::endl;
return false;
}
if (result.count("version")) {
std::cout << MNN_VERSION << std::endl;
return false;
}
modelPath.model = modelPath.MAX_SOURCE;
// model source
std::string frameWork;
if (result.count("framework")) {
frameWork = result["framework"].as<std::string>();
if ("TF" == frameWork) {
modelPath.model = modelConfig::TENSORFLOW;
} else if ("CAFFE" == frameWork) {
modelPath.model = modelConfig::CAFFE;
} else if ("ONNX" == frameWork) {
modelPath.model = modelConfig::ONNX;
} else if ("MNN" == frameWork) {
modelPath.model = modelConfig::MNN;
} else if ("TFLITE" == frameWork) {
modelPath.model = modelConfig::TFLITE;
#ifdef MNN_BUILD_TORCH
} else if ("TORCH" == frameWork) {
modelPath.model = modelConfig::TORCH;
#endif
} else if ("JSON" == frameWork) {
modelPath.model = modelConfig::JSON;
} else {
std::cout << "Framework Input ERROR or Not Support This Model Type Now!" << std::endl;
return false;
}
} else {
std::cout << options.help({""}) << std::endl;
DLOG(INFO) << "framework Invalid, use -f CAFFE/MNN/ONNX/TFLITE/TORCH/JSON !";
return false;
}
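// dump the ops supported by the selected framework, then exit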
if (result.count("OP")) {
MNN_PRINT("Dump %s support Ops\n", frameWork.c_str());
const auto& res = OpCount::get()->getMap().find(frameWork);
if (res == OpCount::get()->getMap().end()) {
return false;
}
for (const auto& iter : res->second) {
MNN_PRINT("%s\n", iter.c_str());
}
MNN_PRINT("Total: %d\n", (int)res->second.size());
return false;
}
// model file path
if (result.count("modelFile")) {
const std::string modelFile = result["modelFile"].as<std::string>();
if (CommonKit::FileIsExist(modelFile)) {
modelPath.modelFile = modelFile;
} else {
DLOG(INFO) << "Model File Does Not Exist! ==> " << modelFile;
return false;
}
} else {
DLOG(INFO) << "modelFile Not set Invalid, use --modelFile to set!";
return false;
}
// Optimize Level
if (result.count("optimizeLevel")) {
modelPath.optimizeLevel = result["optimizeLevel"].as<int>();
if (modelPath.optimizeLevel > 1) {
DLOG(INFO) << "\n optimizeLevel > 1, some case may be wrong";
}
}
// prototxt file path
if (result.count("prototxt")) {
const std::string prototxt = result["prototxt"].as<std::string>();
if (CommonKit::FileIsExist(prototxt)) {
modelPath.prototxtFile = prototxt;
} else {
DLOG(INFO) << "Proto File Does Not Exist!";
return false;
}
} else {
// caffe model must have this option
if (modelPath.model == modelPath.CAFFE) {
DLOG(INFO) << "Proto File Not Set, use --prototxt XXX.prototxt to set it!";
return false;
}
}
// MNN model output path
if (result.count("MNNModel")) {
const std::string MNNModelPath = result["MNNModel"].as<std::string>();
modelPath.MNNModel = MNNModelPath;
} else if (result.count("JsonFile")) {
const std::string JsonFilePath = result["JsonFile"].as<std::string>();
modelPath.mnn2json = true;
modelPath.MNNModel = JsonFilePath;
} else if (result.count("info") && modelPath.model == modelConfig::MNN) {
modelPath.dumpInfo = true;
return true;
} else {
DLOG(INFO) << "MNNModel File Not Set, use --MNNModel XXX.prototxt to set it!";
return false;
}
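// target engine version for compatibility, default: the converter's own version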
if (result.count("targetVersion")) {
auto version = result["targetVersion"].as<float>();
std::cout << "TargetVersion is " << version << std::endl;
modelPath.targetVersion = version;
}
// add MNN bizCode
if (result.count("bizCode")) {
const std::string bizCode = result["bizCode"].as<std::string>();
modelPath.bizCode = bizCode;
} else {
MNN_ERROR("Don't has bizCode, use MNNTest for default\n");
modelPath.bizCode = "MNNTest";
}
// input config file path
if (result.count("inputConfigFile")) {
const std::string inputConfigFile = result["inputConfigFile"].as<std::string>();
modelPath.inputConfigFile = inputConfigFile;
}
// half float
if (result.count("fp16")) {
modelPath.saveHalfFloat = true;
}
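// keep training ops (BN, Dropout) in the converted model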
if (result.count("forTraining")) {
modelPath.forTraining = true;
}
if (result.count("batch")) {
modelPath.defaultBatchSize = result["batch"].as<int>();
}
if (result.count("keepInputFormat")) {
modelPath.keepInputFormat = result["keepInputFormat"].as<bool>();
}
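// weight-only quantization options: bit width, symmetric vs. asymmetric, block size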
if (result.count("weightQuantBits")) {
modelPath.weightQuantBits = result["weightQuantBits"].as<int>();
}
if (result.count("weightQuantAsymmetric")) {
modelPath.weightQuantAsymmetric = result["weightQuantAsymmetric"].as<bool>();
}
if (result.count("weightQuantBlock")) {
modelPath.weightQuantBlock = result["weightQuantBlock"].as<int>();
}
if (result.count("saveStaticModel")) {
modelPath.saveStaticModel = true;
}
if (result.count("optimizePrefer")) {
modelPath.optimizePrefer = result["optimizePrefer"].as<int>();
}
// Int8 calibration table path.
if (result.count("compressionParamsFile")) {
modelPath.compressionParamsFile =
result["compressionParamsFile"].as<std::string>();
}
if (result.count("customOpLibs")) {
modelPath.customOpLibs = result["customOpLibs"].as<std::string>();
}
if (result.count("authCode")) {
modelPath.authCode = result["authCode"].as<std::string>();
}
if (result.count("alignDenormalizedValue")) {
modelPath.alignDenormalizedValue = result["alignDenormalizedValue"].as<int>();
}
if (result.count("detectSparseSpeedUp")) {
modelPath.detectSparseSpeedUp = true;
}
if (result.count("convertMatmulToConv")) {
modelPath.convertMatmulToConv = result["convertMatmulToConv"].as<int>();
}
if (result.count("useGeluApproximation")) {
modelPath.useGeluApproximation = result["useGeluApproximation"].as<int>();
}
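// correctness-test options: convert, then run against data in testdir (or a testconfig json) and compare with the original model within the given threshold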
if (result.count("testdir")) {
modelPath.testDir = result["testdir"].as<std::string>();
}
if (result.count("testconfig")) {
modelPath.testConfig = result["testconfig"].as<std::string>();
}
if (result.count("thredhold")) {
modelPath.testThredhold = result["thredhold"].as<float>();
}
if (result.count("saveExternalData")) {
modelPath.saveExternalData = true;
}
if (result.count("transformerFuse")) {
modelPath.transformerFuse = true;
}
if (result.count("allowCustomOp")) {
modelPath.allowCustomOp = true;
}
return true;
}