in src/contrib/msc/plugin/tensorrt_codegen.cc [452:625]
/*!
 * \brief Emit the plugin methods shared by the static and dynamic TensorRT plugin classes.
 *
 * When \p in_declare is true, this emits the declarations that go inside the class body:
 * two constructors, a deleted default constructor, the destructor and the bookkeeping
 * overrides (serialization, type/version/namespace, init/terminate/destroy, clone).
 * Otherwise it emits the matching out-of-class definitions, qualified as "<op_cls>::<name>".
 *
 * \param plugin The plugin description. Its name, attrs, inputs and outputs are read here.
 * \param dynamic Selects the IPluginV2DynamicExt variant. The clone() return type becomes
 *        IPluginV2DynamicExt* and the plugin type gets a "_dynamic" suffix. When false,
 *        IPluginV2 is used.
 * \param in_declare Emit declarations (true) or definitions (false).
 */
void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dynamic,
                                                   bool in_declare) {
  const auto& op_cls = OpCls(plugin, dynamic);
  const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2";
  if (in_declare) {
    stack_.comment("common methods for " + op_cls);
    // constructor from attrs: (name, <one const-ref arg per attr>, layouts)
    stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&");
    for (const auto& a : plugin->attrs) {
      stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&");
    }
    stack_.constructor_arg("layouts", "const std::vector<std::string>&")
        // constructor from a serialized buffer: (name, buffer, length)
        .constructor_def(op_cls)
        .constructor_arg("name", "const std::string&")
        .constructor_arg("buffer", "const void*")
        .constructor_arg("length", "size_t")
        // default construction is disallowed: "<op_cls>() = delete;"
        .assign(op_cls + "()", "delete")
        .line()
        .constructor_def("~" + op_cls)
        .func_def("getSerializationSize", "size_t")
        .func_decorator("const noexcept override")
        .func_def("serialize")
        .func_decorator("const noexcept override")
        .func_arg("buffer", "void*")
        .func_def("getPluginType", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getPluginVersion", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getPluginNamespace", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getNbOutputs", "int")
        .func_decorator("const noexcept override")
        .func_def("setPluginNamespace")
        .func_decorator("noexcept override")
        .func_arg("name_space", "const char*")
        .func_def("initialize", "int")
        .func_decorator("noexcept override")
        .func_def("terminate")
        .func_decorator("noexcept override")
        .func_def("destroy")
        .func_decorator("noexcept override")
        .func_def("clone", plugin_cls + "*")
        .func_decorator("const noexcept override");
  } else {
    const auto& attr_name = MetaAttrCls(plugin);
    // constructor from attrs: copies name, every attr into meta_attr_, and the layouts.
    stack_.constructor_def(op_cls + "::" + op_cls).constructor_arg("name", "const std::string&");
    for (const auto& a : plugin->attrs) {
      stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&");
    }
    stack_.constructor_arg("layouts", "const std::vector<std::string>&")
        .constructor_start()
        .assign("name_", "name");
    for (const auto& a : plugin->attrs) {
      stack_.assign(DocUtils::ToAttrAccess("meta_attr_", a->name), a->name);
    }
    // one layout string is expected per plugin input
    stack_.line("assert(layouts.size() == " + std::to_string(plugin->inputs.size()) + ");")
        .assign("layouts_", "layouts");
    stack_.constructor_end();
    // constructor from data: reads the fields in the same order serialize() writes them
    // (meta attrs, dtype_, layouts_) and asserts that exactly `length` bytes were consumed.
    stack_.constructor_def(op_cls + "::" + op_cls)
        .constructor_arg("name", "const std::string&")
        .constructor_arg("buffer", "const void*")
        .constructor_arg("length", "size_t")
        .constructor_start()
        .assign("name_", "name")
        .func_call("static_cast<const char*>", DocUtils::ToDeclare("const char*", "char_buf"))
        .call_arg("buffer")
        .assign("start_buf", "char_buf", "const char*")
        .func_call(attr_name + "_deserialize", "char_buf")
        .call_arg("meta_attr_")
        .call_arg("char_buf")
        .func_call("TRTUtils::ValFromBuffer")
        .call_arg("char_buf")
        .call_arg("dtype_")
        .func_call("TRTUtils::ValFromBuffer")
        .call_arg("char_buf")
        .call_arg("layouts_")
        .line("assert(layouts_.size() == " + std::to_string(plugin->inputs.size()) + ");")
        .line("assert(char_buf == (start_buf + length));")
        .constructor_end();
    // destructor: the generated body only carries a comment
    stack_.constructor_def(op_cls + "::~" + op_cls)
        .constructor_start()
        .comment("ignore deconstruct of " + op_cls)
        .constructor_end();
    // getSerializationSize: must mirror the write order in serialize().
    // NOTE(review): the layouts_ term assumes TRTUtils::ValToBuffer encodes a
    // vector<string> as a size_t count followed by (size_t length, chars) per entry;
    // confirm against TRTUtils.
    stack_.func_def(op_cls + "::getSerializationSize", "size_t")
        .func_decorator("const noexcept")
        .func_start()
        .assign("size", attr_name + "_serialize_size()", "size_t")
        .assign("size", "size + sizeof(dtype_)")
        .assign("size", "size + sizeof(size_t)")
        .for_start("layout", "layouts_")
        .assign("size", "size + sizeof(size_t) + layout.size() * sizeof(char)")
        .for_end()
        .func_end("size");
    // serialize: writes meta attrs, dtype_, layouts_ and asserts the size bookkeeping
    stack_.func_def(op_cls + "::serialize")
        .func_decorator("const noexcept")
        .func_arg("buffer", "void*")
        .func_start()
        .func_call("static_cast<char*>", DocUtils::ToDeclare("char*", "char_buf"))
        .call_arg("buffer")
        .assign("start_buf", "char_buf", "const char*")
        .func_call(attr_name + "_serialize", "char_buf")
        .call_arg("meta_attr_")
        .call_arg("char_buf")
        .func_call("TRTUtils::ValToBuffer")
        .call_arg("char_buf")
        .call_arg("dtype_")
        .func_call("TRTUtils::ValToBuffer")
        .call_arg("char_buf")
        .call_arg("layouts_")
        .line("assert(char_buf == (start_buf + getSerializationSize()));")
        .func_end();
    // getPluginType: dynamic plugins get a "_dynamic" suffix so both variants can coexist
    const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : "");
    stack_.func_def(op_cls + "::getPluginType", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(DocUtils::ToStr(plugin_type));
    // getPluginVersion
    stack_.func_def(op_cls + "::getPluginVersion", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(DocUtils::ToStr("1"));
    // getPluginNamespace: returns name_space_.c_str()
    stack_.func_def(op_cls + "::getPluginNamespace", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_call("c_str", DocUtils::ToDeclare("const char*", "name"),
                   DocUtils::ToDoc("name_space_"))
        .func_end("name");
    // getNbOutputs
    stack_.func_def(op_cls + "::getNbOutputs", "int")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(plugin->outputs.size());
    // setPluginNamespace
    stack_.func_def(op_cls + "::setPluginNamespace")
        .func_decorator("noexcept")
        .func_arg("name_space", "const char*")
        .func_start()
        .assign("name_space_", "name_space")
        .func_end();
    // initialize: always reports success (0)
    stack_.func_def(op_cls + "::initialize", "int")
        .func_decorator("noexcept")
        .func_start()
        .func_end(0);
    // terminate: no-op
    stack_.func_def(op_cls + "::terminate")
        .func_decorator("noexcept")
        .func_start()
        .comment("Ignore teminate for " + plugin->name)
        .func_end();
    // destroy: the plugin object deletes itself
    stack_.func_def(op_cls + "::destroy")
        .func_decorator("noexcept")
        .func_start()
        .line("delete this;")
        .func_end();
    // clone: TensorRT expects the clone to carry over the plugin's internal state.
    // The attrs-constructor only restores name/attrs/layouts, so also copy dtype_
    // (otherwise only restored from a serialized buffer) and the namespace set through
    // setPluginNamespace. The local is typed as op_cls* so dtype_ is reachable; it is
    // upcast to the plugin_cls* return type on return.
    stack_.func_def(op_cls + "::clone", plugin_cls + "*")
        .func_decorator("const noexcept")
        .func_start()
        .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin"))
        .call_arg("name_");
    for (const auto& a : plugin->attrs) {
      stack_.call_arg(DocUtils::ToAttrAccess("meta_attr_", a->name));
    }
    stack_.call_arg("layouts_")
        .line("plugin->dtype_ = dtype_;")
        .line("plugin->setPluginNamespace(getPluginNamespace());")
        .func_end("plugin");
  }
}