void TensorRTPluginCodeGen::CodegenCreator()

in src/contrib/msc/plugin/tensorrt_codegen.cc [643:792]


void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) {
  const auto& creator_cls = CreatorCls(plugin, dynamic);
  const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2";
  if (in_declare) {
    stack_.class_def(creator_cls + " :  public IPluginCreator")
        .class_start()
        .scope_start("public:")
        .constructor_def(creator_cls)
        .func_def("getPluginName", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getPluginVersion", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getPluginNamespace", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getFieldNames", "const PluginFieldCollection*")
        .func_decorator("noexcept override")
        .func_def("setPluginNamespace")
        .func_decorator("noexcept override")
        .func_arg("name_space", "const char*")
        .func_def("createPlugin", plugin_cls + "*")
        .func_decorator("noexcept override")
        .func_arg("name", "const char*")
        .func_arg("collection", "const PluginFieldCollection*")
        .func_def("deserializePlugin", plugin_cls + "*")
        .func_decorator("noexcept override")
        .func_arg("name", "const char*")
        .func_arg("data", "const void*")
        .func_arg("length", "size_t")
        .scope_end()
        .scope_start("private:")
        .declare("static PluginFieldCollection", "collection_")
        .declare("static std::vector<PluginField>", "fields_")
        .declare("std::string", "name_space_")
        .scope_end()
        .line()
        .class_end();
  } else {
    const String& attr_name = MetaAttrCls(plugin);
    // static members
    stack_.comment("static members and register for " + plugin->name)
        .declare("PluginFieldCollection", creator_cls + "::collection_")
        .declare("std::vector<PluginField>", creator_cls + "::fields_")
        .func_call("REGISTER_TENSORRT_PLUGIN")
        .call_arg(creator_cls)
        .line();
    // constructor
    stack_.constructor_def(creator_cls + "::" + creator_cls)
        .constructor_start()
        .func_call(attr_name + "_to_fields")
        .call_arg("fields_");
    for (const auto& t : plugin->inputs) {
      stack_.func_call("emplace_back", "", "fields_")
          .inplace_start("TRTUtils::ToField")
          .call_arg(DocUtils::ToStr("layout_" + t->name))
          .call_arg(DocUtils::ToStr("string"))
          .inplace_end();
    }
    const auto& nb_fields_doc = DocUtils::ToAttrAccess("collection_", "nbFields");
    const auto& fields_doc = DocUtils::ToAttrAccess("collection_", "fields");
    stack_.func_call("size", nb_fields_doc, DocUtils::ToDoc("fields_"))
        .func_call("data", fields_doc, DocUtils::ToDoc("fields_"))
        .constructor_end();
    // getPluginName
    const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : "");
    stack_.func_def(creator_cls + "::getPluginName", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(DocUtils::ToStr(plugin_type));
    // getPluginVersion
    stack_.func_def(creator_cls + "::getPluginVersion", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(DocUtils::ToStr("1"));
    // getPluginNamespace
    stack_.func_def(creator_cls + "::getPluginNamespace", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_call("c_str", DocUtils::ToDeclare("const char*", "name"),
                   DocUtils::ToDoc("name_space_"))
        .func_end("name");
    // getFieldNames
    stack_.func_def(creator_cls + "::getFieldNames", "const PluginFieldCollection*")
        .func_decorator("noexcept")
        .func_start()
        .func_end("&collection_");
    // setPluginNamespace
    stack_.func_def(creator_cls + "::setPluginNamespace")
        .func_decorator("noexcept")
        .func_arg("name_space", "const char*")
        .func_start()
        .assign("name_space_", "name_space")
        .func_end();
    // createPlugin
    size_t fields_size = plugin->attrs.size() + plugin->inputs.size();
    const auto& op_cls = OpCls(plugin, dynamic);
    stack_.func_def(creator_cls + "::createPlugin", plugin_cls + "*")
        .func_decorator("noexcept")
        .func_arg("name", "const char*")
        .func_arg("collection", "const PluginFieldCollection*")
        .func_start()
        .line("assert(collection->nbFields == " + std::to_string(fields_size) + ");")
        .assign("fields", DocUtils::ToAttrAccess(DocUtils::ToPtr("collection"), "fields"),
                "const PluginField*")
        .func_call(attr_name + "_from_fields", DocUtils::ToDeclare("const auto&", "meta_attr"))
        .call_arg("fields")
        .declare("std::vector<std::string>", "layouts")
        .func_call("resize", "", "layouts")
        .call_arg(plugin->inputs.size())
        .for_start("i", plugin->attrs.size(), fields_size);
    for (size_t i = 0; i < plugin->inputs.size(); i++) {
      const auto& tensor = plugin->inputs[i];
      const String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0";
      if (i == 0) {
        stack_.switch_start(cond);
      } else {
        stack_.switch_case(cond);
      }
      stack_.func_call("TRTUtils::FromField")
          .call_arg(DocUtils::ToIndex("fields", "i"))
          .call_arg(DocUtils::ToIndex("layouts", i));
    }
    stack_.switch_end()
        .for_end()
        .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
        .call_arg("name");
    for (const auto& a : plugin->attrs) {
      stack_.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name));
    }
    stack_.call_arg("layouts")
        .func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin"))
        .inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_"))
        .inplace_end()
        .func_end("plugin");
    // deserializePlugin
    stack_.func_def(creator_cls + "::deserializePlugin", plugin_cls + "*")
        .func_decorator("noexcept")
        .func_arg("name", "const char*")
        .func_arg("data", "const void*")
        .func_arg("length", "size_t")
        .func_start()
        .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
        .call_arg("name")
        .call_arg("data")
        .call_arg("length")
        .func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin"))
        .inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_"))
        .inplace_end()
        .func_end("plugin");
  }
}