in src/contrib/msc/plugin/tensorrt_codegen.cc [233:418]
/*!
 * \brief Emit the out-of-class method definitions of the TensorRT plugin op classes.
 *
 * When the plugin is not mixed precision, the static op class (OpCls(plugin, false)) is emitted
 * first: common methods, getOutputDimensions, configureWithFormat, supportsFormat,
 * getWorkspaceSize, enqueue and its creator. The dynamic op class (OpCls(plugin, true)) is always
 * emitted: common methods, getOutputDataType, getOutputDimensions, configurePlugin,
 * supportsFormatCombination, getWorkspaceSize, enqueue and its creator.
 * \param plugin The plugin whose op classes are generated.
 */
void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) {
  if (!IsMixPrecision(plugin)) {
    // static op
    const auto& op_static = OpCls(plugin, false);
    CodegenOpCommonMethods(plugin, false, false);
    // getOutputDimensions
    stack_.func_def(op_static + "::getOutputDimensions", "Dims")
        .func_decorator("noexcept")
        .func_arg("index", "int")
        .func_arg("in_dims", "const Dims*")
        .func_arg("n_inputs", "int")
        .func_start();
    CodegenOutputInfer(plugin, false);
    stack_
        .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"),
                   DocUtils::ToIndex("output_metas_", "index"))
        .func_call("TRTUtils::ToDims", DocUtils::ToDeclare("Dims", "out_dims"))
        .call_arg("out_shape")
        .func_end("out_dims");
    // configureWithFormat: stores the dtype chosen by TensorRT into dtype_
    stack_.func_def(op_static + "::configureWithFormat")
        .func_decorator("noexcept")
        .func_arg("in_dims", "const Dims*")
        .func_arg("n_inputs", "int")
        .func_arg("out_dims", "const Dims*")
        .func_arg("n_outputs", "int")
        .func_arg("dtype", "DataType")
        .func_arg("format", "PluginFormat")
        .func_arg("max_batch", "int")
        .func_start()
        .assign("dtype_", "dtype")
        .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");");
    CodegenOutputInfer(plugin, false);
    stack_.func_end();
    // supportsFormat
    stack_.func_def(op_static + "::supportsFormat", "bool")
        .func_decorator("const noexcept")
        .func_arg("dtype", "DataType")
        .func_arg("format", "PluginFormat")
        .func_start()
        .declare("bool", "support");
    size_t cnt = 0;
    for (const auto& dtypes : GetDtypeMatrix(plugin)) {
      // TensorRT asks supportsFormat about candidate types before configureWithFormat runs, and
      // dtype_ is only assigned in configureWithFormat. The generated check must therefore test
      // the queried `dtype` argument, not the member.
      // NOTE(review): only the first dtype of each matrix row is used, presumably because the
      // static op shares one dtype across all IO; `format` is not checked - confirm.
      const String& cond = "dtype == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")";
      if (cnt == 0) {
        stack_.switch_start(cond);
      } else {
        stack_.switch_case(cond);
      }
      stack_.assign("support", true);
      cnt++;
    }
    stack_.switch_case().assign("support", false).switch_end().func_end("support");
    // getWorkspaceSize
    stack_.func_def(op_static + "::getWorkspaceSize", "size_t")
        .func_decorator("const noexcept")
        .func_arg("max_batch", "int")
        .func_start()
        .assign("size", 0, "size_t");
    if (plugin->externs.count("infer_buffer")) {
      CodegenBufferInfer(plugin);
    }
    stack_.func_end("size");
    // enqueue
    stack_.func_def(op_static + "::enqueue", "int")
        .func_decorator("noexcept")
        .func_arg("batch_size", "int")
        .func_arg("inputs", "const void* const*")
        .func_arg("outputs", "void* const*")
        .func_arg("workspace", "void*")
        .func_arg("stream", "cudaStream_t")
        .func_start();
    CodegenEnqueue(plugin, false);
    stack_.func_end(0);
    // static creator
    CodegenCreator(plugin, false, false);
  }
  // dynamic op
  const auto& op_dynamic = OpCls(plugin, true);
  CodegenOpCommonMethods(plugin, true, false);
  // getOutputDataType: output i copies the dtype of its reference input when one exists,
  // otherwise it uses the dtype declared on the output.
  stack_.func_def(op_dynamic + "::getOutputDataType", "DataType")
      .func_decorator("const noexcept")
      .func_arg("index", "int")
      .func_arg("in_types", "const DataType*")
      .func_arg("n_inputs", "int")
      .func_start()
      .declare("DataType", "dtype");
  for (size_t i = 0; i < plugin->outputs.size(); i++) {
    if (i == 0) {
      stack_.switch_start("index == " + std::to_string(i));
    } else {
      stack_.switch_case("index == " + std::to_string(i));
    }
    int ref = plugin->FindDtypeRefIdx(plugin->outputs[i]);
    if (ref >= 0) {
      stack_.assign("dtype", DocUtils::ToIndex("in_types", ref));
    } else {
      stack_.func_call("TRTUtils::ToDataType", "dtype")
          .call_arg(DocUtils::ToStr(plugin->outputs[i]->dtype));
    }
  }
  stack_.switch_end().func_end("dtype");
  // getOutputDimensions
  stack_.func_def(op_dynamic + "::getOutputDimensions", "DimsExprs")
      .func_decorator("noexcept")
      .func_arg("index", "int")
      .func_arg("in_dims", "const DimsExprs*")
      .func_arg("n_inputs", "int")
      .func_arg("builder", "IExprBuilder&")
      .func_start();
  CodegenOutputInfer(plugin, false);
  stack_
      .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"),
                 DocUtils::ToIndex("output_metas_", "index"))
      .func_call("TRTUtils::ToDimsExprs", DocUtils::ToDeclare("DimsExprs", "out_dims"))
      .call_arg("out_shape")
      .call_arg("builder")
      .func_end("out_dims");
  // configurePlugin
  stack_.func_def(op_dynamic + "::configurePlugin")
      .func_decorator("noexcept")
      .func_arg("in_descs", "const DynamicPluginTensorDesc*")
      .func_arg("n_inputs", "int")
      .func_arg("out_descs", "const DynamicPluginTensorDesc*")
      .func_arg("n_outputs", "int")
      .func_start()
      .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");");
  CodegenOutputInfer(plugin, true);
  stack_.func_end();
  // supportsFormatCombination
  stack_.func_def(op_dynamic + "::supportsFormatCombination", "bool")
      .func_decorator("noexcept")
      .func_arg("pos", "int")
      .func_arg("io_desc", "const PluginTensorDesc*")
      .func_arg("n_inputs", "int")
      .func_arg("n_outputs", "int")
      .func_start()
      .declare("bool", "support");
  size_t cnt = 0;
  for (const auto& dtypes : GetDtypeMatrix(plugin)) {
    // When TensorRT asks about position `pos`, only io_desc[0..pos] are decided; later entries
    // hold undetermined values. Input i (i > 0) is therefore only checked once pos >= i, giving
    // the generated predicate
    //   io_desc[0].type == T0 && (pos < 1 || io_desc[1].type == T1) && ...
    // NOTE(review): output descriptors and `format` are never checked here - confirm this is
    // intended.
    String cond;
    for (size_t i = 0; i < plugin->inputs.size(); i++) {
      if (i > 0) {
        cond = cond + " && (pos < " + std::to_string(i) + " || ";
      }
      cond = cond + "io_desc[" + std::to_string(i) + "].type == TRTUtils::ToDataType(\"" +
             dtypes.at(i) + "\")";
      if (i > 0) {
        cond = cond + ")";
      }
    }
    // A plugin without inputs has no input constraint; avoid emitting an empty `if ()`.
    if (cond.empty()) {
      cond = "true";
    }
    if (cnt == 0) {
      stack_.switch_start(cond);
    } else {
      stack_.switch_case(cond);
    }
    stack_.assign("support", true);
    cnt++;
  }
  stack_.switch_case().assign("support", false).switch_end().func_end("support");
  // getWorkspaceSize
  stack_.func_def(op_dynamic + "::getWorkspaceSize", "size_t")
      .func_decorator("const noexcept")
      .func_arg("in_descs", "const PluginTensorDesc*")
      .func_arg("n_inputs", "int")
      .func_arg("out_descs", "const PluginTensorDesc*")
      .func_arg("n_outputs", "int")
      .func_start()
      .assign("size", 0, "size_t");
  if (plugin->externs.count("infer_buffer")) {
    CodegenBufferInfer(plugin);
  }
  stack_.func_end("size");
  // enqueue
  stack_.func_def(op_dynamic + "::enqueue", "int")
      .func_decorator("noexcept")
      .func_arg("input_descs", "const PluginTensorDesc*")
      .func_arg("output_descs", "const PluginTensorDesc*")
      .func_arg("inputs", "const void* const*")
      .func_arg("outputs", "void* const*")
      .func_arg("workspace", "void*")
      .func_arg("stream", "cudaStream_t")
      .func_start();
  CodegenEnqueue(plugin, true);
  stack_.func_end(0);
  // dynamic creator
  CodegenCreator(plugin, true, false);
}