in src/contrib/msc/framework/tensorrt/codegen.cc [308:449]
/*!
 * \brief Emit the `main()` entry point of the standalone TensorRT program.
 *
 * The generated program:
 *  1. creates a TRTLogger whose severity follows config()->log_level;
 *  2. defines `pass`, `repeat_num` and `profile_level` (argv[1] overrides the
 *     profile level baked in from config()->profile_level);
 *  3. when "<graph>.trt" does not exist: creates builder / network / builder
 *     config, builds the model (surrounded by before_build_codes_ and
 *     after_build_codes_), sets the profiling verbosity and serializes the
 *     engine to "<graph>.trt";
 *  4. deserializes the engine from "<graph>.trt";
 *  5. when profile_level > 0, writes engine information as JSON to
 *     "<graph>_info.json" through an engine inspector;
 *  6. when config()->test_iter > 0, runs test_<graph> over config()->dataset.
 * The function body is closed with `pass ? 0 : 1` as the result expression.
 */
void TensorRTCodeGen::CodeGenMain() {
  stack_.line("#include \"" + graph()->name + ".h\"")
      .line()
      .line("using namespace nvinfer1;")
      .line("using namespace tvm::contrib::msc;")
      .line()
      .func_def("main", "int")
      .func_arg("argc", "int")
      .func_arg("argv", "char**")
      .func_start()
      .declare("TRTLogger", "logger")
      .func_call("setLogSeverity", "", "logger");
  // Map config()->log_level onto the emitted logger severity:
  // 0 -> kINFO, 1 -> kVERBOSE, anything else -> kWARNING.
  if (config()->log_level == 0) {
    stack_.call_arg("ILogger::Severity::kINFO");
  } else if (config()->log_level == 1) {
    stack_.call_arg("ILogger::Severity::kVERBOSE");
  } else {
    stack_.call_arg("ILogger::Severity::kWARNING");
  }
  // Runtime arguments of the generated program; argv[1], when present,
  // overrides the profile level taken from config()->profile_level.
  stack_.comment("Define arguments")
      .assign("pass", "true", "bool")
      .assign("repeat_num", "1000", "int")
      .assign("profile_level", std::to_string(config()->profile_level), "int")
      .cond_if("argc > 1")
      .assign("profile_level", "atoi(argv[1])")
      .cond_end();
  // Everything up to the matching cond_end() below only runs in the generated
  // program when "<graph>.trt" is missing, so an existing engine file is reused.
  stack_.comment("Build engine if not exist")
      .cond_if("!FileUtils::FileExist(\"" + graph()->name + ".trt\")");
  // Create builder.
  stack_.comment("Create TensorRT tools")
      .func_call("TRTPtr<IBuilder>", DocUtils::ToDeclare("auto", "builder"))
      .func_call("createInferBuilder")
      .call_arg("logger")
      .pop_nest();
  ReturnOnFail("builder", "Failed to create builder");
  // Create network: from TensorRT 6 on, use createNetworkV2 with an explicit batch flag.
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_
        .assign("flags",
                "1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)",
                "uint32_t")
        .func_call("TRTPtr<INetworkDefinition>", DocUtils::ToDeclare("auto", "network"))
        .func_call("createNetworkV2", NullOpt, DocUtils::ToPtr("builder"))
        .call_arg("flags")
        .pop_nest();
  } else {
    stack_.func_call("TRTPtr<INetworkDefinition>", DocUtils::ToDeclare("auto", "network"))
        .func_call("createNetwork", NullOpt, DocUtils::ToPtr("builder"))
        .pop_nest();
  }
  ReturnOnFail("network", "Failed to create network");
  // Create builder config.
  stack_.func_call("TRTPtr<IBuilderConfig>", DocUtils::ToDeclare("auto", "config"))
      .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtr("builder"))
      .pop_nest();
  ReturnOnFail("config", "Failed to create config");
  // Extra code registered to run before the model is built.
  for (const auto& l : before_build_codes_) {
    stack_.line(l);
  }
  // Build the model; the builder config is only passed for TensorRT >= 6.
  stack_.comment("Build model")
      .declare(graph()->name, "model")
      .func_call("Build", "pass", "model")
      .call_arg("builder")
      .call_arg("network");
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_.call_arg("config");
  }
  stack_.call_arg("logger");
  ReturnOnFail("pass", "Failed to build model");
  // Extra code registered to run after the model is built.
  for (const auto& l : after_build_codes_) {
    stack_.line(l);
  }
  // Select the profiling verbosity from the runtime profile level:
  // 2 -> kDETAILED, 1 -> kLAYER_NAMES_ONLY, otherwise kNONE.
  // NOTE(review): ProfilingVerbosity::kLAYER_NAMES_ONLY / kDETAILED (and the
  // engine inspector below) are believed to require TensorRT >= 8.2 and are
  // emitted without a version guard — confirm the minimum supported version.
  stack_.comment("Set profile flag")
      .declare("ProfilingVerbosity", "profile_verbose")
      .cond_if("profile_level == 2")
      .assign("profile_verbose", "ProfilingVerbosity::kDETAILED")
      .cond_else()
      .cond_if("profile_level == 1")
      .assign("profile_verbose", "ProfilingVerbosity::kLAYER_NAMES_ONLY")
      .cond_else()
      .assign("profile_verbose", "ProfilingVerbosity::kNONE")
      .cond_end()
      .cond_end()
      .func_call("setProfilingVerbosity", NullOpt, DocUtils::ToPtr("config"))
      .call_arg("profile_verbose");
  // Serialize the engine to "<graph>.trt".
  stack_.comment("Serialize engine")
      .func_call("TRTUtils::SerializeEngineToFile", "pass")
      .call_arg(DocUtils::ToStr(graph()->name + ".trt"))
      .call_arg("builder")
      .call_arg("network");
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_.call_arg("config");
  }
  stack_.call_arg("logger");
  ReturnOnFail("pass", "Failed to serialize the engine");
  // Close the "build engine if not exist" branch.
  stack_.cond_end();
  // Deserialize the (possibly freshly built) engine from "<graph>.trt".
  stack_.comment("Deserialize engine")
      .declare("std::shared_ptr<ICudaEngine>", "engine")
      .func_call("TRTUtils::DeserializeEngineFromFile", "pass")
      .call_arg(DocUtils::ToStr(graph()->name + ".trt"))
      .call_arg("engine")
      .call_arg("logger");
  ReturnOnFail("pass", "Failed to deserialize the engine");
  // When profiling is requested at runtime, dump the engine layer information
  // as JSON into "<graph>_info.json".
  stack_.comment("Dump info by inspector")
      .cond_if("profile_level > 0")
      .func_call("TRTPtr<IEngineInspector>", DocUtils::ToDeclare("auto", "inspector"))
      .func_call("createEngineInspector", NullOpt, DocUtils::ToPtr("engine"))
      .pop_nest()
      .func_call("getEngineInformation", DocUtils::ToDeclare("std::string", "result"),
                 DocUtils::ToPtr("inspector"))
      .call_arg("LayerInformationFormat::kJSON")
      .declare("std::ofstream", "os")
      .declare_arg(DocUtils::ToStr(graph()->name + "_info.json"))
      .declare_arg("std::ofstream::trunc")
      .line("os << result << std::flush;")
      .cond_end();
  // Optionally test the engine on the configured dataset. The failure check is
  // emitted only together with the test call: outside this branch `pass` is not
  // reassigned after the deserialize check, so a check here would only repeat
  // that one under a misleading "test" message.
  if (config()->test_iter > 0) {
    stack_.comment("Prepare dataset")
        .declare("DatasetReader", "reader")
        .declare_arg(DocUtils::ToStr(config()->dataset))
        .declare_arg(config()->test_iter);
    stack_.comment("Test engine by datas")
        .func_call("test_" + graph()->name, "pass")
        .call_arg("engine")
        .call_arg("reader")
        .call_arg("logger");
    ReturnOnFail("pass", "Failed to test the engine");
  }
  stack_.func_end("pass ? 0 : 1");
}