void TensorRTPluginCodeGen::CodeGenOpDefine()

in src/contrib/msc/plugin/tensorrt_codegen.cc [233:418]


void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
  if (!IsMixPrecision(plugin)) {
    // static op
    const auto& op_static = OpCls(plugin, false);
    CodegenOpCommonMethods(plugin, false, false);
    // getOutputDimensions
    stack_.func_def(op_static + "::getOutputDimensions", "Dims")
        .func_decorator("noexcept")
        .func_arg("index", "int")
        .func_arg("in_dims", "const Dims*")
        .func_arg("n_inputs", "int")
        .func_start();
    CodegenOutputInfer(plugin, false);
    stack_
        .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"),
                   DocUtils::ToIndex("output_metas_", "index"))
        .func_call("TRTUtils::ToDims", DocUtils::ToDeclare("Dims", "out_dims"))
        .call_arg("out_shape")
        .func_end("out_dims");
    // configureWithFormat
    stack_.func_def(op_static + "::configureWithFormat")
        .func_decorator("noexcept")
        .func_arg("in_dims", "const Dims*")
        .func_arg("n_inputs", "int")
        .func_arg("out_dims", "const Dims*")
        .func_arg("n_outputs", "int")
        .func_arg("dtype", "DataType")
        .func_arg("format", "PluginFormat")
        .func_arg("max_batch", "int")
        .func_start()
        .assign("dtype_", "dtype")
        .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");");
    CodegenOutputInfer(plugin, false);
    stack_.func_end();
    // supportsFormat
    stack_.func_def(op_static + "::supportsFormat", "bool")
        .func_decorator("const noexcept")
        .func_arg("dtype", "DataType")
        .func_arg("format", "PluginFormat")
        .func_start()
        .declare("bool", "support");
    size_t cnt = 0;
    for (const auto& dtypes : GetDtypeMatrix(plugin)) {
      const String& cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")";
      if (cnt == 0) {
        stack_.switch_start(cond);
      } else {
        stack_.switch_case(cond);
      }
      stack_.assign("support", true);
      cnt++;
    }
    stack_.switch_case().assign("support", false).switch_end().func_end("support");
    // getWorkspaceSize
    stack_.func_def(op_static + "::getWorkspaceSize", "size_t")
        .func_decorator("const noexcept")
        .func_arg("max_batch", "int")
        .func_start()
        .assign("size", 0, "size_t");
    if (plugin->externs.count("infer_buffer")) {
      CodegenBufferInfer(plugin);
    }
    stack_.func_end("size");
    // enqueue
    stack_.func_def(op_static + "::enqueue", "int")
        .func_decorator("noexcept")
        .func_arg("batch_size", "int")
        .func_arg("inputs", "const void* const*")
        .func_arg("outputs", "void* const*")
        .func_arg("workspace", "void*")
        .func_arg("stream", "cudaStream_t")
        .func_start();
    CodegenEnqueue(plugin, false);
    stack_.func_end(0);

    // static creator
    CodegenCreator(plugin, false, false);
  }
  // dynamic op
  const auto& op_dynamic = OpCls(plugin, true);
  CodegenOpCommonMethods(plugin, true, false);
  // getOutputDataType
  stack_.func_def(op_dynamic + "::getOutputDataType", "DataType")
      .func_decorator("const noexcept")
      .func_arg("index", "int")
      .func_arg("in_types", "const DataType*")
      .func_arg("n_inputs", "int")
      .func_start()
      .declare("DataType", "dtype");
  for (size_t i = 0; i < plugin->outputs.size(); i++) {
    if (i == 0) {
      stack_.switch_start("index == " + std::to_string(i));
    } else {
      stack_.switch_case("index == " + std::to_string(i));
    }
    int ref = plugin->FindDtypeRefIdx(plugin->outputs[i]);
    if (ref >= 0) {
      stack_.assign("dtype", DocUtils::ToIndex("in_types", ref));
    } else {
      stack_.func_call("TRTUtils::ToDataType", "dtype")
          .call_arg(DocUtils::ToStr(plugin->outputs[i]->dtype));
    }
  }
  stack_.switch_end().func_end("dtype");
  // getOutputDimensions
  stack_.func_def(op_dynamic + "::getOutputDimensions", "DimsExprs")
      .func_decorator("noexcept")
      .func_arg("index", "int")
      .func_arg("in_dims", "const DimsExprs*")
      .func_arg("n_inputs", "int")
      .func_arg("builder", "IExprBuilder&")
      .func_start();
  CodegenOutputInfer(plugin, false);
  stack_
      .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"),
                 DocUtils::ToIndex("output_metas_", "index"))
      .func_call("TRTUtils::ToDimsExprs", DocUtils::ToDeclare("DimsExprs", "out_dims"))
      .call_arg("out_shape")
      .call_arg("builder")
      .func_end("out_dims");
  // configurePlugin
  stack_.func_def(op_dynamic + "::configurePlugin")
      .func_decorator("noexcept")
      .func_arg("in_descs", "const DynamicPluginTensorDesc*")
      .func_arg("n_inputs", "int")
      .func_arg("out_descs", "const DynamicPluginTensorDesc*")
      .func_arg("n_outputs", "int")
      .func_start()
      .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");");
  CodegenOutputInfer(plugin, true);
  stack_.func_end();
  // supportsFormatCombination
  stack_.func_def(op_dynamic + "::supportsFormatCombination", "bool")
      .func_decorator("noexcept")
      .func_arg("pos", "int")
      .func_arg("io_desc", "const PluginTensorDesc*")
      .func_arg("n_inputs", "int")
      .func_arg("n_outputs", "int")
      .func_start()
      .declare("bool", "support");
  size_t cnt = 0;
  for (const auto& dtypes : GetDtypeMatrix(plugin)) {
    String cond;
    for (size_t i = 0; i < plugin->inputs.size(); i++) {
      cond = cond + "io_desc[" + std::to_string(i) + "].type == TRTUtils::ToDataType(\"" +
             dtypes.at(i) + "\")";
      cond = cond + (i == plugin->inputs.size() - 1 ? "" : " && ");
    }
    if (cnt == 0) {
      stack_.switch_start(cond);
    } else {
      stack_.switch_case(cond);
    }
    stack_.assign("support", true);
    cnt++;
  }
  stack_.switch_case().assign("support", false).switch_end().func_end("support");
  // getWorkspaceSize
  stack_.func_def(op_dynamic + "::getWorkspaceSize", "size_t")
      .func_decorator("const noexcept")
      .func_arg("in_descs", "const PluginTensorDesc*")
      .func_arg("n_inputs", "int")
      .func_arg("out_descs", "const PluginTensorDesc*")
      .func_arg("n_outputs", "int")
      .func_start()
      .assign("size", 0, "size_t");
  if (plugin->externs.count("infer_buffer")) {
    CodegenBufferInfer(plugin);
  }
  stack_.func_end("size");
  // enqueue
  stack_.func_def(op_dynamic + "::enqueue", "int")
      .func_decorator("noexcept")
      .func_arg("input_descs", "const PluginTensorDesc*")
      .func_arg("output_descs", "const PluginTensorDesc*")
      .func_arg("inputs", "const void* const*")
      .func_arg("outputs", "void* const*")
      .func_arg("workspace", "void*")
      .func_arg("stream", "cudaStream_t")
      .func_start();
  CodegenEnqueue(plugin, true);
  stack_.func_end(0);

  // dynamic creator
  CodegenCreator(plugin, true, false);
}