void TensorRTCodeGen::CodeGenMain()

in src/contrib/msc/framework/tensorrt/codegen.cc [308:449]


void TensorRTCodeGen::CodeGenMain() {
  stack_.line("#include \"" + graph()->name + ".h\"")
      .line()
      .line("using namespace nvinfer1;")
      .line("using namespace tvm::contrib::msc;")
      .line()
      .func_def("main", "int")
      .func_arg("argc", "int")
      .func_arg("argv", "char**")
      .func_start()
      .declare("TRTLogger", "logger")
      .func_call("setLogSeverity", "", "logger");
  if (config()->log_level == 0) {
    stack_.call_arg("ILogger::Severity::kINFO");
  } else if (config()->log_level == 1) {
    stack_.call_arg("ILogger::Severity::kVERBOSE");
  } else {
    stack_.call_arg("ILogger::Severity::kWARNING");
  }
  // prepare for build
  stack_.comment("Define arguments")
      .assign("pass", "true", "bool")
      .assign("repeat_num", "1000", "int")
      .assign("profile_level", std::to_string(config()->profile_level), "int")
      .cond_if("argc > 1")
      .assign("profile_level", "atoi(argv[1])")
      .cond_end();

  // start build the engine
  stack_.comment("Build engine if not exist")
      .cond_if("!FileUtils::FileExist(\"" + graph()->name + ".trt\")");
  // create builder
  stack_.comment("Create TensorRT tools")
      .func_call("TRTPtr<IBuilder>", DocUtils::ToDeclare("auto", "builder"))
      .func_call("createInferBuilder")
      .call_arg("logger")
      .pop_nest();
  ReturnOnFail("builder", "Failed to create builder");
  // create network
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_
        .assign("flags",
                "1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)",
                "uint32_t")
        .func_call("TRTPtr<INetworkDefinition>", DocUtils::ToDeclare("auto", "network"))
        .func_call("createNetworkV2", NullOpt, DocUtils::ToPtr("builder"))
        .call_arg("flags")
        .pop_nest();
  } else {
    stack_.func_call("TRTPtr<INetworkDefinition>", DocUtils::ToDeclare("auto", "network"))
        .func_call("createNetwork", NullOpt, DocUtils::ToPtr("builder"))
        .pop_nest();
  }
  ReturnOnFail("network", "Failed to create network");
  // create config
  stack_.func_call("TRTPtr<IBuilderConfig>", DocUtils::ToDeclare("auto", "config"))
      .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtr("builder"))
      .pop_nest();
  ReturnOnFail("config", "Failed to create config");
  // add codegen before build
  for (const auto& l : before_build_codes_) {
    stack_.line(l);
  }
  // build model
  stack_.comment("Build model")
      .declare(graph()->name, "model")
      .func_call("Build", "pass", "model")
      .call_arg("builder")
      .call_arg("network");
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_.call_arg("config");
  }
  stack_.call_arg("logger");
  ReturnOnFail("pass", "Failed to build model");
  // add codegen after build
  for (const auto& l : after_build_codes_) {
    stack_.line(l);
  }
  // Set profile flag
  stack_.comment("Set profile flag")
      .declare("ProfilingVerbosity", "profile_verbose")
      .cond_if("profile_level == 2")
      .assign("profile_verbose", "ProfilingVerbosity::kDETAILED")
      .cond_else()
      .cond_if("profile_level == 1")
      .assign("profile_verbose", "ProfilingVerbosity::kLAYER_NAMES_ONLY")
      .cond_else()
      .assign("profile_verbose", "ProfilingVerbosity::kNONE")
      .cond_end()
      .cond_end()
      .func_call("setProfilingVerbosity", NullOpt, DocUtils::ToPtr("config"))
      .call_arg("profile_verbose");
  // Serialize engine
  stack_.comment("Serialize engine")
      .func_call("TRTUtils::SerializeEngineToFile", "pass")
      .call_arg(DocUtils::ToStr(graph()->name + ".trt"))
      .call_arg("builder")
      .call_arg("network");
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_.call_arg("config");
  }
  stack_.call_arg("logger");
  ReturnOnFail("pass", "Failed to serialize the engine");
  // end build the engine
  stack_.cond_end();
  // start deserialize engine
  stack_.comment("Deserialize engine")
      .declare("std::shared_ptr<ICudaEngine>", "engine")
      .func_call("TRTUtils::DeserializeEngineFromFile", "pass")
      .call_arg(DocUtils::ToStr(graph()->name + ".trt"))
      .call_arg("engine")
      .call_arg("logger");
  ReturnOnFail("pass", "Failed to deserialize the engine");
  // dump info by inspector
  stack_.comment("Dump info by inspector")
      .cond_if("profile_level > 0")
      .func_call("TRTPtr<IEngineInspector>", DocUtils::ToDeclare("auto", "inspector"))
      .func_call("createEngineInspector", NullOpt, DocUtils::ToPtr("engine"))
      .pop_nest()
      .func_call("getEngineInformation", DocUtils::ToDeclare("std::string", "result"),
                 DocUtils::ToPtr("inspector"))
      .call_arg("LayerInformationFormat::kJSON")
      .declare("std::ofstream", "os")
      .declare_arg(DocUtils::ToStr(graph()->name + "_info.json"))
      .declare_arg("std::ofstream::trunc")
      .line("os << result << std::flush;")
      .cond_end();
  // test engine
  if (config()->test_iter > 0) {
    stack_.comment("Prepare dataset")
        .declare("DatasetReader", "reader")
        .declare_arg(DocUtils::ToStr(config()->dataset))
        .declare_arg(config()->test_iter);
    stack_.comment("Test engine by datas")
        .func_call("test_" + graph()->name, "pass")
        .call_arg("engine")
        .call_arg("reader")
        .call_arg("logger");
  }
  ReturnOnFail("pass", "Failed to test the engine");
  stack_.func_end("pass ? 0 : 1");
}