in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc [1160:2041]
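// Compile() turns each fused TensorRT subgraph into a callable kernel:
//  1. Serialize the fused node's function body back into an ONNX ModelProto and parse it into a TensorRT network.
//  2. Apply precision (FP16/INT8), DLA and workspace options; build and optionally cache the engine now when all
//     input shapes are static, otherwise defer building until a call that sees concrete shapes.
//  3. Register create_state/compute/release_state callbacks that bind I/O buffers, rebuild the engine when a
//     dynamic input falls outside the recorded shape range, and run enqueueV2 on the provider's CUDA stream.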
common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {
// Build map from input name to its index in input definitions
std::unordered_map<std::string, size_t> input_map;
const auto& input_defs = fused_node->InputDefs();
input_map.reserve(input_defs.size());
for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
input_map[input_defs[i]->Name()] = i;
}
// Build map from output name to its index in output definitions
std::unordered_map<std::string, size_t> output_map;
const auto& output_defs = fused_node->OutputDefs();
output_map.reserve(output_defs.size());
for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
output_map[output_defs[i]->Name()] = i;
}
// Reconstruct graph proto from fused node's function body
const auto* func_body = fused_node->GetFunctionBody();
if (!func_body) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
}
const Graph& graph_body = func_body->Body();
auto graph_body_viewer = graph_body.CreateGraphViewer();
auto model = graph_body_viewer->CreateModel(*GetLogger());
auto model_proto = model->ToProto();
*model_proto->mutable_graph() = *graph_body.ToGraphProto();
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string string_buf;
model_proto->SerializeToString(string_buf);
if (dump_subgraphs_) {
// Dump TensorRT subgraphs
std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
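// Create the TensorRT builder, an explicit-batch network, a builder config and the ONNX parser,
// then parse the serialized subgraph into the network and cap the builder workspace size.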
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = tensorrt_ptr::unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = tensorrt_ptr::unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
auto trt_config = tensorrt_ptr::unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
trt_config->setMaxWorkspaceSize(max_workspace_size_);
int num_inputs = trt_network->getNbInputs();
int num_outputs = trt_network->getNbOutputs();
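// Per-tensor metadata: ORT input/output indexes, output ONNX element types, and per-dimension
// [min, max] shape ranges used to build the optimization profile for dynamic inputs.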
std::unordered_map<std::string, size_t> input_indexes(num_inputs);
std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>> input_shape_ranges;
std::unordered_map<std::string, size_t> output_indexes(num_outputs);
std::unordered_map<std::string, size_t> output_types(num_outputs);
// Initialize shape range for dynamic shape tensors
bool has_dynamic_shape = false;
for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
auto input = trt_network->getInput(i);
const std::string& input_name = input->getName();
nvinfer1::Dims dims = input->getDimensions();
int nb_dims = dims.nbDims;
if (input->isShapeTensor()) {
// Shape tensor
input_shape_ranges[input_name][0] = std::make_pair(INT_MAX, INT_MIN);
has_dynamic_shape = true;
} else {
// Execution tensor
for (int j = 0, end = nb_dims; j < end; ++j) {
if (dims.d[j] == -1) {
input_shape_ranges[input_name][j] = std::make_pair(INT_MAX, INT_MIN);
has_dynamic_shape = true;
}
}
}
}
// Check platform availability for low precision
if (fp16_enable_) {
if (!trt_builder->platformHasFastFp16()) {
fp16_enable_ = false;
LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16";
}
}
if (int8_enable_) {
if (!trt_builder->platformHasFastInt8()) {
int8_enable_ = false;
LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8";
}
}
// Load INT8 calibration table
std::unordered_map<std::string, float> dynamic_range_map;
if (int8_enable_ && int8_calibration_cache_available_) {
const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_);
if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) {
throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path);
}
}
// Set precision flags
std::string trt_node_name_with_precision = fused_node->Name();
if (fp16_enable_ && int8_enable_) {
trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
trt_node_name_with_precision += "_fp16_int8";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled";
} else if (fp16_enable_) {
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
trt_node_name_with_precision += "_fp16";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled";
} else if (int8_enable_) {
trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
trt_node_name_with_precision += "_int8";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled";
}
// Set DLA
if (fp16_enable_ || int8_enable_) {
if (dla_enable_ && dla_core_ >= 0) {  // DLA can only run with FP16 or INT8
int number_of_dla_core = trt_builder->getNbDLACores();
if (number_of_dla_core == 0) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core";
dla_enable_ = false;
} else {
if (dla_core_ >= number_of_dla_core) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead.";
dla_core_ = 0;
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_;
trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
trt_config->setDLACore(dla_core_);
trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_);
}
}
}
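// The engine cache file name below is derived from the fused node name plus the precision/DLA
// suffixes accumulated in trt_node_name_with_precision, so engines built with different options do not collide.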
// Build TRT engine here if the graph doesn't have dynamic shape input. Otherwise engine will
// be built at runtime
tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine> trt_engine;
tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext> trt_context;
if (!has_dynamic_shape) {
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
if (engine_cache_enable_ && engine_file) {
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
} else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) {
// Decrypt engine
size_t engine_size = 0;
if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
} else {
// Set INT8 per tensor dynamic range
if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node->Name());
}
}
// Build engine
{
auto lock = GetEngineBuildLock();
trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
}
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build engine for fused node: " + fused_node->Name());
}
if (engine_cache_enable_) {
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
size_t engine_size = serializedModel->size();
if (engine_decryption_enable_) {
// Encrypt engine
if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
}
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}
}
// Build context
trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
if (trt_context == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not build execution context for fused node: " + fused_node->Name());
}
}
// Create input to index map
for (int i = 0; i < num_inputs; ++i) {
auto input = trt_network->getInput(i);
const std::string& input_name = input->getName();
const auto& iter = input_map.find(input_name);
if (iter != input_map.end()) {
input_indexes[input_name] = iter->second;
}
}
// Create output to index and type maps
const auto& graph_output = model_proto->graph().output();
for (int i = 0; i < num_outputs; ++i) {
const std::string& output_name = trt_network->getOutput(i)->getName();
const auto& iter = output_map.find(output_name);
if (iter != output_map.end()) {
output_indexes[output_name] = iter->second;
}
const auto& tensor_type = graph_output[i].type().tensor_type();
output_types[output_name] = tensor_type.elem_type();
}
// Save engine, context and input/output info to map
parsers_.emplace(fused_node->Name(), std::move(trt_parser));
engines_.emplace(fused_node->Name(), std::move(trt_engine));
contexts_.emplace(fused_node->Name(), std::move(trt_context));
builders_.emplace(fused_node->Name(), std::move(trt_builder));
networks_.emplace(fused_node->Name(), std::move(trt_network));
input_info_[fused_node->Name()].push_back(input_indexes);
output_info_[fused_node->Name()].push_back(output_indexes);
output_info_[fused_node->Name()].push_back(output_types);
input_shape_ranges_[fused_node->Name()] = input_shape_ranges;
// Create function state
// TODO: remove default capture
NodeComputeInfo compute_info;
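// The function state stores raw pointers into the provider-owned maps populated above
// (parsers_, engines_, contexts_, builders_, networks_), so the cached objects are shared
// with, and outlive, each compute call.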
compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
std::unique_ptr<TensorrtFuncState> p = std::make_unique<TensorrtFuncState>();
*p = {context->allocate_func, context->release_func, context->allocator_handle, &parsers_[context->node_name],
&engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
*state = p.release();
return 0;
};
// Release function state
compute_info.release_state_func = [](FunctionState state) {
if (state)
delete static_cast<TensorrtFuncState*>(state);
};
// Create compute function
compute_info.compute_func = [this](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
Ort::CustomOpApi ort{*api};
TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
auto& shape_ranges = trt_state->input_shape_ranges;
auto trt_builder = trt_state->builder->get();
auto trt_engine = trt_state->engine->get();
auto trt_context = trt_state->context->get();
auto trt_profile = &(trt_state->trt_profile);
auto alloc = trt_state->scratch_allocator;
int num_inputs = static_cast<int>(input_indexes.size());
int num_outputs = static_cast<int>(output_indexes.size());
bool engine_update = false;
std::unordered_set<std::string> input_names;
std::unordered_map<std::string, std::vector<int32_t>> tensor_shape_values;
cudaStream_t stream = static_cast<cudaStream_t>(this->GetComputeStream());
// If engine caching is enabled and the engine was not built at compile time (dynamic-shape case),
// try to restore the serialized engine and its shape-range profile from disk
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
const std::string profile_cache_path = cache_path + ".profile";
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
if (engine_file && profile_file) {
// Deserialize profile
shape_ranges = DeserializeProfile(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Deserialize engine
trt_state->context->reset();
trt_state->engine->reset();
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (*(trt_state->context) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
} else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
shape_ranges = DeserializeProfile(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
// Decrypt engine
size_t engine_size = 0;
if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not get engine buffer size");
}
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine decryption function decrypt");
}
// Deserialize engine
trt_state->context->reset();
trt_state->engine->reset();
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (*(trt_state->context) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
}
}
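// For every input with a recorded shape range, compare the incoming tensor shape (or the values of a
// shape tensor) against the cached [min, max] bounds; widen the bounds and set engine_update when a
// shape falls outside them, so the engine is rebuilt below with an updated optimization profile.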
for (int i = 0, end = num_inputs; i < end; ++i) {
auto input = trt_state->network->get()->getInput(i);
const std::string& input_name = input->getName();
nvinfer1::Dims dims = input->getDimensions();
int nb_dims = dims.nbDims;
// Check and update shape ranges for dynamic shape inputs
input_names.insert(input_name);
if (shape_ranges.find(input_name) != shape_ranges.end()) {
size_t input_index = 0;
const auto& iter = input_indexes.find(input_name);
if (iter != input_indexes.end()) {
input_index = iter->second;
}
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
const auto& tensor_shapes = ort.GetTensorShape(tensor_info);
auto& shape_range = shape_ranges[input_name];
// Create shape profile
if (input->isShapeTensor()) {
// Get shape values for shape tensor input
const auto& tensor_type = ort.GetTensorElementType(tensor_info);
int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
tensor_shape_values[input_name].resize(shape_size);
switch (tensor_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
// Use an RAII host buffer so the early returns in CUDA_RETURN_IF_ERROR cannot leak it
auto shape_values_buf = std::make_unique<int32_t[]>(shape_size);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values_buf.get(), ort.GetTensorData<int32_t>(input_tensor), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
for (int j = 0; j < shape_size; ++j) {
tensor_shape_values[input_name][j] = shape_values_buf[j];
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
auto shape_values_buf = std::make_unique<int64_t[]>(shape_size);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values_buf.get(), ort.GetTensorData<int64_t>(input_tensor), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
for (int j = 0; j < shape_size; ++j) {
tensor_shape_values[input_name][j] = static_cast<int32_t>(shape_values_buf[j]);
}
break;
}
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
}
}
// Update shape ranges
std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
int shape_range_size = static_cast<int>(shape_range.size());
if (shape_size == shape_range_size) {
// If shape size matches, check/update shape range
for (int j = 0; j < shape_size; ++j) {
shapes_min[j] = static_cast<int32_t>(shape_range[j].first);
shapes_opt[j] = static_cast<int32_t>(shape_range[j].second);
shapes_max[j] = static_cast<int32_t>(shape_range[j].second);
const auto& tensor_shape_value = tensor_shape_values[input_name][j];
// Update shape range lower bound
if (tensor_shape_value < shape_range[j].first) {
shape_range[j].first = tensor_shape_value;
shapes_min[j] = tensor_shape_value;
engine_update = true;
}
// Update shape range upper bound
if (tensor_shape_value > shape_range[j].second) {
shape_range[j].second = tensor_shape_value;
shapes_max[j] = tensor_shape_value;
shapes_opt[j] = tensor_shape_value;
engine_update = true;
}
}
} else {
// If shape size doesn't match, initialize shape_range with the new shape value
shape_range.clear();
for (int j = 0; j < shape_size; ++j) {
const auto& tensor_shape_value = tensor_shape_values[input_name][j];
shape_range[j] = std::make_pair(tensor_shape_value, tensor_shape_value);
shapes_min[j] = tensor_shape_value;
shapes_opt[j] = tensor_shape_value;
shapes_max[j] = tensor_shape_value;
}
engine_update = true;
}
if (*trt_profile == nullptr) {
*trt_profile = trt_builder->createOptimizationProfile();
}
(*trt_profile)->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
(*trt_profile)->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
(*trt_profile)->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
} else { // Execution tensor
nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
for (int j = 0, end = nb_dims; j < end; ++j) {
const auto& tensor_shape = tensor_shapes[j];
if (shape_range.find(j) != shape_range.end()) {
dims_min.d[j] = static_cast<int32_t>(shape_range[j].first);
dims_opt.d[j] = static_cast<int32_t>(shape_range[j].second);
dims_max.d[j] = static_cast<int32_t>(shape_range[j].second);
// Update minimum dimension
if (tensor_shape < shape_range[j].first) {
shape_range[j].first = tensor_shape;
dims_min.d[j] = static_cast<int32_t>(tensor_shape);
engine_update = true;
}
// Update maximum dimension
if (tensor_shape > shape_range[j].second) {
shape_range[j].second = tensor_shape;
dims_max.d[j] = static_cast<int32_t>(tensor_shape);
dims_opt.d[j] = static_cast<int32_t>(tensor_shape);
engine_update = true;
}
}
}
if (*trt_profile == nullptr) {
*trt_profile = trt_builder->createOptimizationProfile();
}
(*trt_profile)->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
(*trt_profile)->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
(*trt_profile)->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
}
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
}
}
// Regenerate engine
// Only one profile is generated, so no need to explicitly set optimization profile
if (engine_update) {
trt_state->context->reset();
trt_state->engine->reset();
auto trt_config = tensorrt_ptr::unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
trt_config->addOptimizationProfile(*trt_profile);
// Set INT8 Per Tensor Dynamic range
if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
trt_config->setInt8Calibrator(nullptr);
if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
}
}
// Set precision
if (trt_state->fp16_enable && trt_state->int8_enable) {
trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
} else if (trt_state->fp16_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
} else if (trt_state->int8_enable) {
trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
}
// Set DLA (DLA can only run with FP16 or INT8)
if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
trt_config->setDLACore(trt_state->dla_core);
}
// Build engine
{
auto lock = GetEngineBuildLock();
*(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
}
if (*(trt_state->engine) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
trt_engine = trt_state->engine->get();
if (trt_state->engine_cache_enable) {
// Serialize engine profile
SerializeProfile(profile_cache_path, shape_ranges);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
// Serialize engine
nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
size_t engine_size = serializedModel->size();
if (trt_state->engine_decryption_enable) {
// Encrypt engine
if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
} else {
std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
}
serializedModel->destroy();
}
// Build context
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
if (*(trt_state->context) == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
trt_context = trt_state->context->get();
}
// Get input and output binding names
int total_bindings = trt_engine->getNbBindings();
std::vector<void*> buffers(total_bindings);
std::vector<std::string> input_binding_names, output_binding_names;
for (int i = 0, end = total_bindings; i < end; ++i) {
if (trt_engine->bindingIsInput(i)) {
input_binding_names.push_back(trt_engine->getBindingName(i));
} else {
output_binding_names.push_back(trt_engine->getBindingName(i));
}
}
// Set input shapes and assign input buffers
std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
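// scratch_buffers owns the temporary device allocations used for absent inputs/outputs and for the
// INT64/INT32 and DOUBLE/FLOAT conversions, keeping them alive for the rest of this compute call.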
for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
const std::string& input_name = input_binding_names[i];
int binding_index = trt_engine->getBindingIndex(input_name.c_str());
if (binding_index == -1) {
continue;
}
size_t input_index = 0;
const auto& iter = input_indexes.find(input_name);
if (iter != input_indexes.end()) {
input_index = iter->second;
}
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
const auto& tensor_shapes = ort.GetTensorShape(tensor_info);
// Bind the actual runtime dimensions (or shape-tensor values) for this input before enqueue
nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast<int>(binding_index));
int nb_dims = dimensions.nbDims;
if (input_names.count(input_name) == 1) {
if (trt_engine->isShapeBinding(binding_index)) {
trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]);
} else {
for (int j = 0, end = nb_dims; j < end; ++j) {
dimensions.d[j] = static_cast<int32_t>(tensor_shapes[j]);
}
trt_context->setBindingDimensions(binding_index, dimensions);
}
}
const auto& input_type = ort.GetTensorElementType(tensor_info);
switch (input_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
auto input_tensor_ptr = ort.GetTensorData<float>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = const_cast<float*>(input_tensor_ptr);
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
auto input_tensor_ptr = ort.GetTensorData<uint16_t>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(uint16_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = const_cast<uint16_t*>(input_tensor_ptr);
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
auto input_tensor_ptr = ort.GetTensorData<bool>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(bool)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = const_cast<bool*>(input_tensor_ptr);
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
auto input_tensor_ptr = ort.GetTensorData<int8_t>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int8_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = const_cast<int8_t*>(input_tensor_ptr);
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
auto input_tensor_ptr = ort.GetTensorData<int32_t>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = const_cast<int32_t*>(input_tensor_ptr);
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
// Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
auto input_tensor_ptr = ort.GetTensorData<int64_t>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
SafeInt<int> input_dim_size = 1;
for (int j = 0, end = nb_dims; j < end; ++j) {
if (tensor_shapes[j] == 0) {
input_dim_size = 1;
break;
} else {
input_dim_size *= tensor_shapes[j];
}
}
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, input_dim_size * sizeof(int32_t)));
buffers[binding_index] = scratch_buffers.back().get();
cuda::Impl_Cast<int64_t, int32_t>(stream, input_tensor_ptr, reinterpret_cast<int32_t*>(buffers[binding_index]), input_dim_size);
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
// Cast DOUBLE input to FLOAT because TensorRT doesn't fully support DOUBLE
auto input_tensor_ptr = ort.GetTensorData<double>(input_tensor);
if (input_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
SafeInt<int> input_dim_size = 1;
for (int j = 0, end = nb_dims; j < end; ++j) {
if (tensor_shapes[j] == 0) {
input_dim_size = 1;
break;
} else {
input_dim_size *= tensor_shapes[j];
}
}
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, input_dim_size * sizeof(float)));
buffers[binding_index] = scratch_buffers.back().get();
cuda::Impl_Cast<double, float>(stream, input_tensor_ptr, reinterpret_cast<float*>(buffers[binding_index]), input_dim_size);
}
break;
}
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP input onnx tensor data type: " + std::to_string(input_type) + " not supported.");
}
}
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
}
// Set output shapes and assign output buffers
std::vector<int> output_dim_sizes(num_outputs, 1);
std::vector<OrtValue*> output_tensor(num_outputs, nullptr);
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
// Query the output binding index and its runtime dimensions
const std::string& output_name = output_binding_names[i];
int binding_index = trt_engine->getBindingIndex(output_name.c_str());
if (binding_index == -1) {
continue;
}
size_t output_index = 0;
const auto& index_iter = output_indexes.find(output_name);
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast<int>(binding_index));
int nb_dims = dimensions.nbDims;
std::vector<int64_t> output_shapes(nb_dims);
for (int j = 0, end = nb_dims; j < end; ++j) {
output_shapes[j] = dimensions.d[j];
}
output_tensor[i] = ort.KernelContext_GetOutput(context, output_index, output_shapes.data(), output_shapes.size());
size_t output_type = 0;
const auto& type_iter = output_types.find(output_name);
if (type_iter != output_types.end()) {
output_type = type_iter->second;
}
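// For INT64 and DOUBLE outputs the ORT output tensor is allocated with its real type, while TensorRT
// writes into an INT32/FLOAT scratch buffer of matching element count; the results are cast back into
// the output tensor after enqueueV2.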
switch (output_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
auto output_tensor_ptr = ort.GetTensorMutableData<float>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = output_tensor_ptr;
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
auto output_tensor_ptr = ort.GetTensorMutableData<uint16_t>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(uint16_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = output_tensor_ptr;
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
auto output_tensor_ptr = ort.GetTensorMutableData<bool>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(bool)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = output_tensor_ptr;
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
auto output_tensor_ptr = ort.GetTensorMutableData<int8_t>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int8_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = output_tensor_ptr;
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
auto output_tensor_ptr = ort.GetTensorMutableData<int32_t>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
buffers[binding_index] = output_tensor_ptr;
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
// Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64
auto output_tensor_ptr = ort.GetTensorMutableData<int64_t>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
buffers[binding_index] = scratch_buffers.back().get();
output_dim_sizes[i] = 1;
} else {
SafeInt<int> output_dim_size(output_dim_sizes[i]);
for (int j = 0, end = nb_dims; j < end; ++j) {
if (dimensions.d[j] == 0) {
output_dim_size = 1;
break;
} else {
output_dim_size *= dimensions.d[j];
}
}
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, output_dim_size * sizeof(int32_t)));
buffers[binding_index] = scratch_buffers.back().get();
output_dim_sizes[i] = output_dim_size;
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
// Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE
auto output_tensor_ptr = ort.GetTensorMutableData<double>(output_tensor[i]);
if (output_tensor_ptr == nullptr) {
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
buffers[binding_index] = scratch_buffers.back().get();
} else {
SafeInt<int> output_dim_size(output_dim_sizes[i]);
for (int j = 0, end = nb_dims; j < end; ++j) {
if (dimensions.d[j] == 0) {
output_dim_size = 1;
break;
} else {
output_dim_size *= dimensions.d[j];
}
}
scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, output_dim_size * sizeof(float)));
buffers[binding_index] = scratch_buffers.back().get();
output_dim_sizes[i] = output_dim_size;
}
break;
}
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.");
}
}
}
// Run TRT inference
if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}
// Cast INT32 scratch outputs back to INT64 and FLOAT scratch outputs back to DOUBLE, since TensorRT doesn't fully support INT64/DOUBLE
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
const std::string& output_name = output_binding_names[i];
size_t binding_index = trt_engine->getBindingIndex(output_name.c_str());
size_t output_type = 0;
const auto& iter = output_types.find(output_name);
if (iter != output_types.end()) {
output_type = iter->second;
}
if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
auto output_tensor_ptr = ort.GetTensorMutableData<int64_t>(output_tensor[i]);
if (output_tensor_ptr != nullptr) {
cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]);
}
} else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
auto output_tensor_ptr = ort.GetTensorMutableData<double>(output_tensor[i]);
if (output_tensor_ptr != nullptr) {
cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]);
}
}
}
return Status::OK();
};
node_compute_funcs.push_back(compute_info);
}
return Status::OK();
}