// TensorRTCodeGen::CodeGenClassDefine()
// from src/contrib/msc/framework/tensorrt/codegen.cc [95:306]


void TensorRTCodeGen::CodeGenClassDefine() {
  auto malloc_buffer = [this](const MSCTensor& tensor) {
    const String& idx_var = "idx_" + IdxTensor(tensor);
    this->stack_
        .func_call("getBindingIndex", DocUtils::ToDeclare("int", idx_var),
                   DocUtils::ToPtr("engine"))
        .call_arg(DocUtils::ToStr(tensor->name))
        .func_call("CHECK")
        .func_call("cudaMalloc")
        .call_arg(DocUtils::ToIndex("&gpu_buffers", idx_var))
        .call_arg(GetTensorBytes(tensor))
        .pop_nest()
        .func_call("malloc", DocUtils::ToIndex("cpu_buffers", idx_var))
        .call_arg(GetTensorBytes(tensor));
  };
  stack_.line("#include \"" + graph()->name + ".h\"").line();
  StartNamespace();
  // start define build method
  stack_.func_def(graph()->name + "::Build", "bool")
      .func_arg("builder", "TRTPtr<IBuilder>&")
      .func_arg("network", "TRTPtr<INetworkDefinition>&");
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_.func_arg("config", "TRTPtr<IBuilderConfig>&");
  }
  stack_.func_arg("logger", "TRTLogger&").func_start();
  // save codegen before build
  if (config()->use_tools) {
    const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step");
    before_build_codes_ =
        pf(GetStepCtx(), "before_build", graph()->name, config()->tools_tag).cast<Array<String>>();
  }
  if (graph()->weight_holders.size() > 0) {
    stack_.func_call("TRTUtils::LoadWeights", "mWeights")
        .call_arg(DocUtils::ToStr(graph()->name + ".wts"));
  }
  // build layers
  for (const auto& n : graph()->node_names) {
    const auto& node = graph()->FindNode(n);
    CodeGenNode(node, config()->use_tools);
  }
  // mark outputs
  stack_.comment("Mark outputs");
  for (const auto& o : graph()->GetOutputs()) {
    const auto& pair = graph()->FindProducerAndIdx(o);
    stack_.func_call("markOutput", NullOpt, DocUtils::ToPtr("network"))
        .call_arg("*" + IdxOutputBase(pair.first, pair.second));
  }
  // mark batch_size
  stack_.comment("Mark batch size");
  stack_.func_call("createOptimizationProfile", DocUtils::ToDeclare("auto", "profile"),
                   DocUtils::ToPtr("builder"));
  Array<String> batch_flags{"MIN", "MAX", "OPT"};
  for (const auto& i : graph()->GetInputs()) {
    for (const auto& f : batch_flags) {
      stack_.func_call("setDimensions", NullOpt, DocUtils::ToPtr("profile"))
          .call_arg(DocUtils::ToStr(i->name))
          .call_arg("OptProfileSelector::k" + f)
          .call_arg(ToDims(i->shape));
    }
  }
  // set max workspace
  stack_.comment("Set max worksapce");
  if (CompareVersion(6, 0, 0) >= 0) {
    stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtr("config"))
        .call_arg(config()->max_workspace);
  } else {
    stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtr("builder"))
        .call_arg(config()->max_workspace);
  }
  // set data type
  if (config()->precision == "float16") {
    stack_.comment("Set network precision")
        .cond_if("!builder->platformHasFastFp16()")
        .func_call("log", "", "logger")
        .call_arg("ILogger::Severity::kINTERNAL_ERROR")
        .call_arg(DocUtils::ToStr("platform do not support float16, fallback to float32"))
        .cond_else()
        .func_call("setFlag", NullOpt, DocUtils::ToPtr("config"))
        .call_arg("BuilderFlag::kFP16");
    if (config()->precision_mode == "strict") {
      stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config"))
          .call_arg("BuilderFlag::kSTRICT_TYPES");
    }
    stack_.func_call("log", "", "logger")
        .call_arg("ILogger::Severity::kINFO")
        .call_arg(DocUtils::ToStr("use float16 to build the engine"))
        .cond_end();
  } else if (config()->precision == "int8") {
    stack_.comment("Set network precision")
        .cond_if("!builder->platformHasFastInt8()")
        .func_call("log", "", "logger")
        .call_arg("ILogger::Severity::kINTERNAL_ERROR")
        .call_arg(DocUtils::ToStr("platform do not support int8, fallback to float32"))
        .cond_else()
        .func_call("setFlag", NullOpt, DocUtils::ToPtr("config"))
        .call_arg("BuilderFlag::kINT8");
    if (config()->precision_mode == "strict") {
      stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config"))
          .call_arg("BuilderFlag::kSTRICT_TYPES");
    } else if (config()->precision_mode == "prefer") {
      stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config"))
          .call_arg("BuilderFlag::kPREFER_PRECISION_CONSTRAINTS");
    } else if (config()->precision_mode == "obey") {
      stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config"))
          .call_arg("BuilderFlag::kOBEY_PRECISION_CONSTRAINTS");
    }
    stack_.func_call("log", "", "logger")
        .call_arg("ILogger::Severity::kINFO")
        .call_arg(DocUtils::ToStr("use int8 to build the engine"))
        .cond_end();
  }
  // save codegen after build
  if (config()->use_tools) {
    const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step");
    after_build_codes_ =
        pf(GetStepCtx(), "after_build", graph()->name, config()->tools_tag).cast<Array<String>>();
  }
  // end define build method
  stack_.func_end("true");
  // start define test function
  stack_.func_def("test_" + graph()->name, "bool")
      .func_arg("engine", "std::shared_ptr<ICudaEngine>&")
      .func_arg("reader", "DatasetReader&")
      .func_arg("logger", "TRTLogger&")
      .func_start();
  stack_.comment("Create context")
      .func_call("TRTPtr<IExecutionContext>", DocUtils::ToDeclare("auto", "context"))
      .func_call("createExecutionContext", NullOpt, DocUtils::ToPtr("engine"))
      .pop_nest();
  ReturnOnFail("context", "Failed to create the context");
  // prepare variables
  stack_.declare("bool", "pass", 0, false)
      .declare_arg("true")
      .declare("cudaStream_t", "stream")
      .func_call("CHECK")
      .func_call("cudaStreamCreate")
      .call_arg("&stream")
      .pop_nest();
  // malloc buffers
  size_t binding_num = graph()->input_names.size() + graph()->output_names.size();
  stack_.comment("Malloc and copy the buffers")
      .declare("void*", "cpu_buffers", binding_num)
      .declare("void*", "gpu_buffers", binding_num);
  for (const auto& i : graph()->GetInputs()) {
    malloc_buffer(i);
  }
  for (const auto& o : graph()->GetOutputs()) {
    malloc_buffer(o);
    stack_.declare(CppDType(o->dtype), "output_" + IdxTensor(o),
                   static_cast<size_t>(o->GetSize()->value));
  }
  // read and test datas
  stack_.comment("Read and test datas")
      .while_start("reader.ReadNext(cpu_buffers)")
      .comment("Memcopy inputs host to device");
  // copy inputs
  for (const auto& i : graph()->GetInputs()) {
    stack_.func_call("CHECK")
        .func_call("cudaMemcpyAsync")
        .call_arg(DocUtils::ToIndex("gpu_buffers", "idx_" + IdxTensor(i)))
        .call_arg(DocUtils::ToIndex("cpu_buffers", "idx_" + IdxTensor(i)))
        .call_arg(GetTensorBytes(i))
        .call_arg("cudaMemcpyHostToDevice")
        .call_arg("stream")
        .pop_nest();
  }
  // enqueue
  stack_.func_call("cudaStreamSynchronize")
      .call_arg("stream")
      .comment("enquque with gpu buffers")
      .func_call("enqueueV2", NullOpt, DocUtils::ToPtr("context"))
      .call_arg("gpu_buffers")
      .call_arg("stream")
      .call_arg("nullptr")
      .comment("Memcopy outputs device to host");
  // copy outputs
  for (const auto& o : graph()->GetOutputs()) {
    stack_.func_call("CHECK")
        .func_call("cudaMemcpyAsync")
        .call_arg("output_" + IdxTensor(o))
        .call_arg(DocUtils::ToIndex("gpu_buffers", "idx_" + IdxTensor(o)))
        .call_arg(GetTensorBytes(o))
        .call_arg("cudaMemcpyDeviceToHost")
        .call_arg("stream")
        .pop_nest();
  }
  stack_.func_call("cudaStreamSynchronize").call_arg("stream");
  // compare outputs
  for (const auto& o : graph()->GetOutputs()) {
    stack_.func_call("CommonUtils::CompareBuffers", "pass")
        .call_arg("(" + CppDType(o->dtype) + "*)cpu_buffers[idx_" + IdxTensor(o) + "]")
        .call_arg("output_" + IdxTensor(o))
        .call_arg(o->GetSize());
    ReturnOnFail("pass", "Failed to test the output " + o->name);
  }
  stack_.while_end();
  // clean up
  stack_.comment("Clean up the buffers and stream")
      .func_call("cudaStreamDestroy")
      .call_arg("stream")
      .for_start("i", 0, binding_num)
      .func_call("CHECK")
      .func_call("cudaFree")
      .call_arg(DocUtils::ToIndex("gpu_buffers", "i"))
      .pop_nest()
      .func_call("free")
      .call_arg(DocUtils::ToIndex("cpu_buffers", "i"))
      .for_end();
  // end define test method
  stack_.func_end("true");
  EndNamespace();
}