void TensorRTPluginCodeGen::CodegenOpCommonMethods()

in src/contrib/msc/plugin/tensorrt_codegen.cc [452:625]


void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dynamic,
                                                   bool in_declare) {
  const auto& op_cls = OpCls(plugin, dynamic);
  const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2";
  if (in_declare) {
    stack_.comment("common methods for " + op_cls);
    stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&");
    for (const auto& a : plugin->attrs) {
      stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&");
    }
    stack_.constructor_arg("layouts", "const std::vector<std::string>&")
        .constructor_def(op_cls)
        .constructor_arg("name", "const std::string&")
        .constructor_arg("buffer", "const void*")
        .constructor_arg("length", "size_t")
        .assign(op_cls + "()", "delete")
        .line()
        .constructor_def("~" + op_cls)
        .func_def("getSerializationSize", "size_t")
        .func_decorator("const noexcept override")
        .func_def("serialize")
        .func_decorator("const noexcept override")
        .func_arg("buffer", "void*")
        .func_def("getPluginType", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getPluginVersion", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getPluginNamespace", "const char*")
        .func_decorator("const noexcept override")
        .func_def("getNbOutputs", "int")
        .func_decorator("const noexcept override")
        .func_def("setPluginNamespace")
        .func_decorator("noexcept override")
        .func_arg("name_space", "const char*")
        .func_def("initialize", "int")
        .func_decorator("noexcept override")
        .func_def("terminate")
        .func_decorator("noexcept override")
        .func_def("destroy")
        .func_decorator("noexcept override")
        .func_def("clone", plugin_cls + "*")
        .func_decorator("const noexcept override");
  } else {
    const auto& attr_name = MetaAttrCls(plugin);
    // constructor from attrs
    stack_.constructor_def(op_cls + "::" + op_cls).constructor_arg("name", "const std::string&");
    for (const auto& a : plugin->attrs) {
      stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&");
    }
    stack_.constructor_arg("layouts", "const std::vector<std::string>&")
        .constructor_start()
        .assign("name_", "name");
    for (const auto& a : plugin->attrs) {
      stack_.assign(DocUtils::ToAttrAccess("meta_attr_", a->name), a->name);
    }
    stack_.line("assert(layouts.size() == " + std::to_string(plugin->inputs.size()) + ");")
        .assign("layouts_", "layouts");
    stack_.constructor_end();
    // constructor from data
    stack_.constructor_def(op_cls + "::" + op_cls)
        .constructor_arg("name", "const std::string&")
        .constructor_arg("buffer", "const void*")
        .constructor_arg("length", "size_t")
        .constructor_start()
        .assign("name_", "name")
        .func_call("static_cast<const char*>", DocUtils::ToDeclare("const char*", "char_buf"))
        .call_arg("buffer")
        .assign("start_buf", "char_buf", "const char*")
        .func_call(attr_name + "_deserialize", "char_buf")
        .call_arg("meta_attr_")
        .call_arg("char_buf")
        .func_call("TRTUtils::ValFromBuffer")
        .call_arg("char_buf")
        .call_arg("dtype_")
        .func_call("TRTUtils::ValFromBuffer")
        .call_arg("char_buf")
        .call_arg("layouts_")
        .line("assert(layouts_.size() == " + std::to_string(plugin->inputs.size()) + ");")
        .line("assert(char_buf == (start_buf + length));")
        .constructor_end();
    // deconstructor
    stack_.constructor_def(op_cls + "::~" + op_cls)
        .constructor_start()
        .comment("ignore deconstruct of " + op_cls)
        .constructor_end();
    // getSerializationSize
    stack_.func_def(op_cls + "::getSerializationSize", "size_t")
        .func_decorator("const noexcept")
        .func_start()
        .assign("size", attr_name + "_serialize_size()", "size_t")
        .assign("size", "size + sizeof(dtype_)")
        .assign("size", "size + sizeof(size_t)")
        .for_start("layout", "layouts_")
        .assign("size", "size + sizeof(size_t) + layout.size() * sizeof(char)")
        .for_end()
        .func_end("size");
    // serialize
    stack_.func_def(op_cls + "::serialize")
        .func_decorator("const noexcept")
        .func_arg("buffer", "void*")
        .func_start()
        .func_call("static_cast<char*>", DocUtils::ToDeclare("char*", "char_buf"))
        .call_arg("buffer")
        .assign("start_buf", "char_buf", "const char*")
        .func_call(attr_name + "_serialize", "char_buf")
        .call_arg("meta_attr_")
        .call_arg("char_buf")
        .func_call("TRTUtils::ValToBuffer")
        .call_arg("char_buf")
        .call_arg("dtype_")
        .func_call("TRTUtils::ValToBuffer")
        .call_arg("char_buf")
        .call_arg("layouts_")
        .line("assert(char_buf == (start_buf + getSerializationSize()));")
        .func_end();
    // getPluginType
    const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : "");
    stack_.func_def(op_cls + "::getPluginType", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(DocUtils::ToStr(plugin_type));
    // getPluginVersion
    stack_.func_def(op_cls + "::getPluginVersion", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(DocUtils::ToStr("1"));
    // getPluginNamespace
    stack_.func_def(op_cls + "::getPluginNamespace", "const char*")
        .func_decorator("const noexcept")
        .func_start()
        .func_call("c_str", DocUtils::ToDeclare("const char*", "name"),
                   DocUtils::ToDoc("name_space_"))
        .func_end("name");
    // getNbOutputs
    stack_.func_def(op_cls + "::getNbOutputs", "int")
        .func_decorator("const noexcept")
        .func_start()
        .func_end(plugin->outputs.size());
    // setPluginNamespace
    stack_.func_def(op_cls + "::setPluginNamespace")
        .func_decorator("noexcept")
        .func_arg("name_space", "const char*")
        .func_start()
        .assign("name_space_", "name_space")
        .func_end();
    // initialize
    stack_.func_def(op_cls + "::initialize", "int")
        .func_decorator("noexcept")
        .func_start()
        .func_end(0);
    // terminate
    stack_.func_def(op_cls + "::terminate")
        .func_decorator("noexcept")
        .func_start()
        .comment("Ignore teminate for " + plugin->name)
        .func_end();
    // destroy
    stack_.func_def(op_cls + "::destroy")
        .func_decorator("noexcept")
        .func_start()
        .line("delete this;")
        .func_end();
    // clone
    stack_.func_def(op_cls + "::clone", plugin_cls + "*")
        .func_decorator("const noexcept")
        .func_start()
        .func_call("new " + op_cls, DocUtils::ToDeclare(plugin_cls + "*", "plugin"))
        .call_arg("name_");
    for (const auto& a : plugin->attrs) {
      stack_.call_arg(DocUtils::ToAttrAccess("meta_attr_", a->name));
    }
    stack_.call_arg("layouts_").func_end("plugin");
  }
}