in src/contrib/msc/plugin/tensorrt_codegen.cc [643:792]
void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) {
const auto& creator_cls = CreatorCls(plugin, dynamic);
const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2";
if (in_declare) {
stack_.class_def(creator_cls + " : public IPluginCreator")
.class_start()
.scope_start("public:")
.constructor_def(creator_cls)
.func_def("getPluginName", "const char*")
.func_decorator("const noexcept override")
.func_def("getPluginVersion", "const char*")
.func_decorator("const noexcept override")
.func_def("getPluginNamespace", "const char*")
.func_decorator("const noexcept override")
.func_def("getFieldNames", "const PluginFieldCollection*")
.func_decorator("noexcept override")
.func_def("setPluginNamespace")
.func_decorator("noexcept override")
.func_arg("name_space", "const char*")
.func_def("createPlugin", plugin_cls + "*")
.func_decorator("noexcept override")
.func_arg("name", "const char*")
.func_arg("collection", "const PluginFieldCollection*")
.func_def("deserializePlugin", plugin_cls + "*")
.func_decorator("noexcept override")
.func_arg("name", "const char*")
.func_arg("data", "const void*")
.func_arg("length", "size_t")
.scope_end()
.scope_start("private:")
.declare("static PluginFieldCollection", "collection_")
.declare("static std::vector<PluginField>", "fields_")
.declare("std::string", "name_space_")
.scope_end()
.line()
.class_end();
} else {
const String& attr_name = MetaAttrCls(plugin);
// static members
stack_.comment("static members and register for " + plugin->name)
.declare("PluginFieldCollection", creator_cls + "::collection_")
.declare("std::vector<PluginField>", creator_cls + "::fields_")
.func_call("REGISTER_TENSORRT_PLUGIN")
.call_arg(creator_cls)
.line();
// constructor
stack_.constructor_def(creator_cls + "::" + creator_cls)
.constructor_start()
.func_call(attr_name + "_to_fields")
.call_arg("fields_");
for (const auto& t : plugin->inputs) {
stack_.func_call("emplace_back", "", "fields_")
.inplace_start("TRTUtils::ToField")
.call_arg(DocUtils::ToStr("layout_" + t->name))
.call_arg(DocUtils::ToStr("string"))
.inplace_end();
}
const auto& nb_fields_doc = DocUtils::ToAttrAccess("collection_", "nbFields");
const auto& fields_doc = DocUtils::ToAttrAccess("collection_", "fields");
stack_.func_call("size", nb_fields_doc, DocUtils::ToDoc("fields_"))
.func_call("data", fields_doc, DocUtils::ToDoc("fields_"))
.constructor_end();
// getPluginName
const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : "");
stack_.func_def(creator_cls + "::getPluginName", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_end(DocUtils::ToStr(plugin_type));
// getPluginVersion
stack_.func_def(creator_cls + "::getPluginVersion", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_end(DocUtils::ToStr("1"));
// getPluginNamespace
stack_.func_def(creator_cls + "::getPluginNamespace", "const char*")
.func_decorator("const noexcept")
.func_start()
.func_call("c_str", DocUtils::ToDeclare("const char*", "name"),
DocUtils::ToDoc("name_space_"))
.func_end("name");
// getFieldNames
stack_.func_def(creator_cls + "::getFieldNames", "const PluginFieldCollection*")
.func_decorator("noexcept")
.func_start()
.func_end("&collection_");
// setPluginNamespace
stack_.func_def(creator_cls + "::setPluginNamespace")
.func_decorator("noexcept")
.func_arg("name_space", "const char*")
.func_start()
.assign("name_space_", "name_space")
.func_end();
// createPlugin
size_t fields_size = plugin->attrs.size() + plugin->inputs.size();
const auto& op_cls = OpCls(plugin, dynamic);
stack_.func_def(creator_cls + "::createPlugin", plugin_cls + "*")
.func_decorator("noexcept")
.func_arg("name", "const char*")
.func_arg("collection", "const PluginFieldCollection*")
.func_start()
.line("assert(collection->nbFields == " + std::to_string(fields_size) + ");")
.assign("fields", DocUtils::ToAttrAccess(DocUtils::ToPtr("collection"), "fields"),
"const PluginField*")
.func_call(attr_name + "_from_fields", DocUtils::ToDeclare("const auto&", "meta_attr"))
.call_arg("fields")
.declare("std::vector<std::string>", "layouts")
.func_call("resize", "", "layouts")
.call_arg(plugin->inputs.size())
.for_start("i", plugin->attrs.size(), fields_size);
for (size_t i = 0; i < plugin->inputs.size(); i++) {
const auto& tensor = plugin->inputs[i];
const String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0";
if (i == 0) {
stack_.switch_start(cond);
} else {
stack_.switch_case(cond);
}
stack_.func_call("TRTUtils::FromField")
.call_arg(DocUtils::ToIndex("fields", "i"))
.call_arg(DocUtils::ToIndex("layouts", i));
}
stack_.switch_end()
.for_end()
.func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
.call_arg("name");
for (const auto& a : plugin->attrs) {
stack_.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name));
}
stack_.call_arg("layouts")
.func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin"))
.inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_"))
.inplace_end()
.func_end("plugin");
// deserializePlugin
stack_.func_def(creator_cls + "::deserializePlugin", plugin_cls + "*")
.func_decorator("noexcept")
.func_arg("name", "const char*")
.func_arg("data", "const void*")
.func_arg("length", "size_t")
.func_start()
.func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
.call_arg("name")
.call_arg("data")
.call_arg("length")
.func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin"))
.inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_"))
.inplace_end()
.func_end("plugin");
}
}