common::Status TensorrtExecutionProvider::Compile()

in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc [1160:2041]

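Compiles each fused TensorRT subgraph node handed over by the partitioner. For every fused node, the subgraph is serialized back into an ONNX model, parsed into a TensorRT network, and, when no input has a dynamic shape, built into an engine up front (honoring the FP16/INT8/DLA settings and the engine cache / encryption options). The function then registers create-state, compute, and release-state callbacks; the compute callback updates optimization-profile shape ranges, rebuilds or reloads cached engines as needed, binds inputs and outputs, and emulates INT64 and DOUBLE tensors with INT32 and FLOAT buffers.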

common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
                                                  std::vector<NodeComputeInfo>& node_compute_funcs) {
  for (const auto* fused_node : fused_nodes) {
    // Build map from input name to its index in input definitions
    std::unordered_map<std::string, size_t> input_map;
    const auto& input_defs = fused_node->InputDefs();
    input_map.reserve(input_defs.size());
    for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
      input_map[input_defs[i]->Name()] = i;
    }

    // Build map from output name to its index in output definitions
    std::unordered_map<std::string, size_t> output_map;
    const auto& output_defs = fused_node->OutputDefs();
    output_map.reserve(output_defs.size());
    for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
      output_map[output_defs[i]->Name()] = i;
    }

    // Reconstruct graph proto from fused node's function body
    const auto* func_body = fused_node->GetFunctionBody();
    if (!func_body) {
      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty");
    }
    const Graph& graph_body = func_body->Body();
    auto graph_body_viewer = graph_body.CreateGraphViewer();
    auto model = graph_body_viewer->CreateModel(*GetLogger());
    auto model_proto = model->ToProto();
    *model_proto->mutable_graph() = *graph_body.ToGraphProto();
    model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
    std::string string_buf;
    model_proto->SerializeToString(string_buf);

    if (dump_subgraphs_) {
      // Dump TensorRT subgraphs
      std::fstream dump(fused_node->Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
      model_proto->SerializeToOstream(dump);
    }

    TensorrtLogger& trt_logger = GetTensorrtLogger();
    auto trt_builder = tensorrt_ptr::unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto trt_network = tensorrt_ptr::unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
    auto trt_config = tensorrt_ptr::unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
    auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
    trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
    trt_config->setMaxWorkspaceSize(max_workspace_size_);
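    // The parser, builder, and network objects are kept alive per fused node (stored in
    // parsers_/builders_/networks_ below) so the compute function can rebuild the engine
    // at run time when dynamic shape ranges change.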

    int num_inputs = trt_network->getNbInputs();
    int num_outputs = trt_network->getNbOutputs();
    std::unordered_map<std::string, size_t> input_indexes(num_inputs);
    std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>> input_shape_ranges;
    std::unordered_map<std::string, size_t> output_indexes(num_outputs);
    std::unordered_map<std::string, size_t> output_types(num_outputs);

    // Initialize shape range for dynamic shape tensors
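    // A (INT_MAX, INT_MIN) pair is a sentinel with min > max, so the first concrete shape
    // observed at run time always widens the range and forces an engine build on the first run.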
    bool has_dynamic_shape = false;
    for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
      auto input = trt_network->getInput(i);
      const std::string& input_name = input->getName();
      nvinfer1::Dims dims = input->getDimensions();
      int nb_dims = dims.nbDims;
      if (input->isShapeTensor()) {
        // Shape tensor
        input_shape_ranges[input_name][0] = std::make_pair(INT_MAX, INT_MIN);
        has_dynamic_shape = true;
      } else {
        // Execution tensor
        for (int j = 0, end = nb_dims; j < end; ++j) {
          if (dims.d[j] == -1) {
            input_shape_ranges[input_name][j] = std::make_pair(INT_MAX, INT_MIN);
            has_dynamic_shape = true;
          }
        }
      }
    }

    // Check platform availability for low precision
    if (fp16_enable_) {
      if (!trt_builder->platformHasFastFp16()) {
        fp16_enable_ = false;
        LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16";
      }
    }

    if (int8_enable_) {
      if (!trt_builder->platformHasFastInt8()) {
        int8_enable_ = false;
        LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8";
      }
    }

    // Load INT8 calibration table
    std::unordered_map<std::string, float> dynamic_range_map;
    if (int8_enable_ && int8_calibration_cache_available_) {
      const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_);
      if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) {
        throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path);
      }
    }

    // Set precision flags
    std::string trt_node_name_with_precision = fused_node->Name();
    if (fp16_enable_ && int8_enable_) {
      trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
      trt_node_name_with_precision += "_fp16_int8";
      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled";
    } else if (fp16_enable_) {
      trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
      trt_node_name_with_precision += "_fp16";
      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled";
    } else if (int8_enable_) {
      trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
      trt_node_name_with_precision += "_int8";
      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled";
    }

    // Set DLA
    if (fp16_enable_ || int8_enable_) {
      if (dla_enable_ && dla_core_ >= 0) {  // DLA can only run with FP16 or INT8
        int number_of_dla_core = trt_builder->getNbDLACores();
        if (number_of_dla_core == 0) {
          LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core";
          dla_enable_ = false;
        } else {
          if (dla_core_ >= number_of_dla_core) {
            LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead.";
            dla_core_ = 0;
          }
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_;
          trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
          trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
          trt_config->setDLACore(dla_core_);
          trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_);
        }
      }
    }

    // Build the TRT engine here if the graph doesn't have any dynamic-shape input. Otherwise the
    // engine is built at runtime in the compute function below.
    tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine> trt_engine;
    tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext> trt_context;
    if (!has_dynamic_shape) {
      const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
      const std::string engine_cache_path = cache_path + ".engine";
      std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
      if (engine_cache_enable_ && engine_file) {
        engine_file.seekg(0, std::ios::end);
        size_t engine_size = engine_file.tellg();
        engine_file.seekg(0, std::ios::beg);
        std::unique_ptr<char[]> engine_buf{new char[engine_size]};
        engine_file.read((char*)engine_buf.get(), engine_size);
        trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
        if (trt_engine == nullptr) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                 "TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
        }
      } else if (engine_decryption_enable_ && engine_cache_enable_ && !engine_file) {
        // Decrypt engine
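        // The decryption callback is invoked twice: first with a null buffer to query the
        // engine size, then with an allocated buffer to receive the decrypted bytes.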
        size_t engine_size = 0;
        if (!engine_decryption_(engine_cache_path.c_str(), nullptr, &engine_size)) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                 "TensorRT EP could not get engine buffer size");
        }
        std::unique_ptr<char[]> engine_buf{new char[engine_size]};
        if (!engine_decryption_(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                 "TensorRT EP could not call engine decryption function decrypt");
        }
        // Deserialize engine
        trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
        if (trt_engine == nullptr) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                 "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
        }
      } else {
        // Set INT8 per tensor dynamic range
        if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
          trt_config->setInt8Calibrator(nullptr);
          if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                   "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node->Name());
          }
        }

        // Build engine
        {
          auto lock = GetEngineBuildLock();
          trt_engine = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_builder->buildEngineWithConfig(*trt_network, *trt_config));
        }
        if (trt_engine == nullptr) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                 "TensorRT EP could not build engine for fused node: " + fused_node->Name());
        }
        if (engine_cache_enable_) {
          nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
          size_t engine_size = serializedModel->size();
          if (engine_decryption_enable_) {
            // Encrypt engine
            if (!engine_encryption_(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
              return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                     "TensorRT EP could not call engine encryption function encrypt");
            }
          } else {
            std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
            file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
          }
          serializedModel->destroy();
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
        }
      }

      // Build context
      trt_context = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
      if (trt_context == nullptr) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                               "TensorRT EP could not build execution context for fused node: " + fused_node->Name());
      }
    }

    // Create input to index map
    for (int i = 0; i < num_inputs; ++i) {
      auto input = trt_network->getInput(i);
      const std::string& input_name = input->getName();
      const auto& iter = input_map.find(input_name);
      if (iter != input_map.end()) {
        input_indexes[input_name] = iter->second;
      }
    }

    // Create output to index and type maps
    const auto& graph_output = model_proto->graph().output();
    for (int i = 0; i < num_outputs; ++i) {
      const std::string& output_name = trt_network->getOutput(i)->getName();
      const auto& iter = output_map.find(output_name);
      if (iter != output_map.end()) {
        output_indexes[output_name] = iter->second;
      }
      const auto& tensor_type = graph_output[i].type().tensor_type();
      output_types[output_name] = tensor_type.elem_type();
    }

    // Save engine, context and input/output info to map
    parsers_.emplace(fused_node->Name(), std::move(trt_parser));
    engines_.emplace(fused_node->Name(), std::move(trt_engine));
    contexts_.emplace(fused_node->Name(), std::move(trt_context));
    builders_.emplace(fused_node->Name(), std::move(trt_builder));
    networks_.emplace(fused_node->Name(), std::move(trt_network));
    input_info_[fused_node->Name()].push_back(input_indexes);
    output_info_[fused_node->Name()].push_back(output_indexes);
    output_info_[fused_node->Name()].push_back(output_types);
    input_shape_ranges_[fused_node->Name()] = input_shape_ranges;

    // Create function state
    // TODO: remove default capture
    NodeComputeInfo compute_info;
    compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
      std::unique_ptr<TensorrtFuncState> p = std::make_unique<TensorrtFuncState>();
      *p = {context->allocate_func, context->release_func, context->allocator_handle, &parsers_[context->node_name],
            &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
            &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
            input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
            dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
            runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
      *state = p.release();
      return 0;
    };

    // Release function state
    compute_info.release_state_func = [](FunctionState state) {
      if (state)
        delete static_cast<TensorrtFuncState*>(state);
    };
    // Create compute function
    compute_info.compute_func = [this](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
      Ort::CustomOpApi ort{*api};
      TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
      std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
      const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
      const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
      const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
      auto& shape_ranges = trt_state->input_shape_ranges;
      auto trt_builder = trt_state->builder->get();
      auto trt_engine = trt_state->engine->get();
      auto trt_context = trt_state->context->get();
      auto trt_profile = &(trt_state->trt_profile);
      auto alloc = trt_state->scratch_allocator;
      int num_inputs = static_cast<int>(input_indexes.size());
      int num_outputs = static_cast<int>(output_indexes.size());
      bool engine_update = false;
      std::unordered_set<std::string> input_names;
      std::unordered_map<std::string, std::vector<int32_t>> tensor_shape_values;
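      // tensor_shape_values caches host copies of shape-tensor inputs so the same values can be
      // reused when binding input shapes further below.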

      cudaStream_t stream = static_cast<cudaStream_t>(this->GetComputeStream());

      // Load serialized engine
      const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
      const std::string engine_cache_path = cache_path + ".engine";
      const std::string profile_cache_path = cache_path + ".profile";
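      // <name>.engine holds the serialized TensorRT engine; <name>.profile holds the serialized
      // shape ranges the engine was built for, and both files are required to reload from cache.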
      if (trt_state->engine_cache_enable && trt_engine == nullptr) {
        std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
        std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
        if (engine_file && profile_file) {
          // Deserialize profile
          shape_ranges = DeserializeProfile(profile_file);
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
          // Deserialize engine
          trt_state->context->reset();
          trt_state->engine->reset();
          engine_file.seekg(0, std::ios::end);
          size_t engine_size = engine_file.tellg();
          engine_file.seekg(0, std::ios::beg);
          std::unique_ptr<char[]> engine_buf{new char[engine_size]};
          engine_file.read((char*)engine_buf.get(), engine_size);
          *(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
              trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
          if (*(trt_state->engine) == nullptr) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine from cache: " + engine_cache_path);
          }
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
          trt_engine = trt_state->engine->get();
          *(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
              trt_state->engine->get()->createExecutionContext());
          if (*(trt_state->context) == nullptr) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
          }
          trt_context = trt_state->context->get();
        } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) {
          shape_ranges = DeserializeProfile(profile_file);
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
          // Decrypt engine
          size_t engine_size = 0;
          if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                   "TensorRT EP could not get engine buffer size");
          }
          std::unique_ptr<char[]> engine_buf{new char[engine_size]};
          if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                   "TensorRT EP could not call engine decryption function decrypt");
          }
          // Deserialize engine
          trt_state->context->reset();
          trt_state->engine->reset();
          *(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr));
          if (*(trt_state->engine) == nullptr) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                   "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path);
          }
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
          trt_engine = trt_state->engine->get();
          *(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
              trt_state->engine->get()->createExecutionContext());
          if (*(trt_state->context) == nullptr) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
          }
          trt_context = trt_state->context->get();
        }
      }

      for (int i = 0, end = num_inputs; i < end; ++i) {
        auto input = trt_state->network->get()->getInput(i);
        const std::string& input_name = input->getName();
        nvinfer1::Dims dims = input->getDimensions();
        int nb_dims = dims.nbDims;
        // Check and update shape ranges for dynamic shape inputs
        input_names.insert(input_name);
        if (shape_ranges.find(input_name) != shape_ranges.end()) {
          size_t input_index = 0;
          const auto& iter = input_indexes.find(input_name);
          if (iter != input_indexes.end()) {
            input_index = iter->second;
          }

          const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
          auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
          const auto& tensor_shapes = ort.GetTensorShape(tensor_info);
          auto& shape_range = shape_ranges[input_name];

          // Create shape profile
          if (input->isShapeTensor()) {
            // Get shape values for shape tensor input
            const auto& tensor_type = ort.GetTensorElementType(tensor_info);
            int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
            tensor_shape_values[input_name].resize(shape_size);
            switch (tensor_type) {
              case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
                int32_t* input = new int32_t[shape_size];
                CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input, ort.GetTensorData<int32_t>(input_tensor), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
                CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
                for (int j = 0; j < shape_size; ++j) {
                  tensor_shape_values[input_name][j] = input[j];
                }
                delete[] input;
                break;
              }
              case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
                int64_t* input = new int64_t[shape_size];
                CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input, ort.GetTensorData<int64_t>(input_tensor), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
                CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
                for (int j = 0; j < shape_size; ++j) {
                  tensor_shape_values[input_name][j] = static_cast<int32_t>(input[j]);
                }
                delete[] input;
                break;
              }
              default: {
                return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                       "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
              }
            }

            // Update shape ranges
            std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
            int shape_range_size = static_cast<int>(shape_range.size());
            if (shape_size == shape_range_size) {
              // If shape size matches, check/update shape range
              for (int j = 0; j < shape_size; ++j) {
                shapes_min[j] = static_cast<int32_t>(shape_range[j].first);
                shapes_opt[j] = static_cast<int32_t>(shape_range[j].second);
                shapes_max[j] = static_cast<int32_t>(shape_range[j].second);

                const auto& tensor_shape_value = tensor_shape_values[input_name][j];
                // Update shape range lower bound
                if (tensor_shape_value < shape_range[j].first) {
                  shape_range[j].first = tensor_shape_value;
                  shapes_min[j] = tensor_shape_value;
                  engine_update = true;
                }
                // Update shape range upper bound
                if (tensor_shape_value > shape_range[j].second) {
                  shape_range[j].second = tensor_shape_value;
                  shapes_max[j] = tensor_shape_value;
                  shapes_opt[j] = tensor_shape_value;
                  engine_update = true;
                }
              }
            } else {
              // If shape size doesn't match, initialize shape_range with the new shape value
              shape_range.clear();
              for (int j = 0; j < shape_size; ++j) {
                const auto& tensor_shape_value = tensor_shape_values[input_name][j];
                shape_range[j] = std::make_pair(tensor_shape_value, tensor_shape_value);
                shapes_min[j] = tensor_shape_value;
                shapes_opt[j] = tensor_shape_value;
                shapes_max[j] = tensor_shape_value;
              }
              engine_update = true;
            }

            if (*trt_profile == nullptr) {
              *trt_profile = trt_builder->createOptimizationProfile();
            }
            (*trt_profile)->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
            (*trt_profile)->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
            (*trt_profile)->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
          } else {  // Execution tensor
            nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
            for (int j = 0, end = nb_dims; j < end; ++j) {
              const auto& tensor_shape = tensor_shapes[j];
              if (shape_range.find(j) != shape_range.end()) {
                dims_min.d[j] = static_cast<int32_t>(shape_range[j].first);
                dims_opt.d[j] = static_cast<int32_t>(shape_range[j].second);
                dims_max.d[j] = static_cast<int32_t>(shape_range[j].second);

                // Update minimum dimension
                if (tensor_shape < shape_range[j].first) {
                  shape_range[j].first = tensor_shape;
                  dims_min.d[j] = static_cast<int32_t>(tensor_shape);
                  engine_update = true;
                }
                // Update maximum dimension
                if (tensor_shape > shape_range[j].second) {
                  shape_range[j].second = tensor_shape;
                  dims_max.d[j] = static_cast<int32_t>(tensor_shape);
                  dims_opt.d[j] = static_cast<int32_t>(tensor_shape);
                  engine_update = true;
                }
              }
            }

            if (*trt_profile == nullptr) {
              *trt_profile = trt_builder->createOptimizationProfile();
            }
            (*trt_profile)->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
            (*trt_profile)->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
            (*trt_profile)->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
          }
          ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
        }
      }

      // Regenerate engine
      // Only one profile is generated, so no need to explicitly set optimization profile
      if (engine_update) {
        trt_state->context->reset();
        trt_state->engine->reset();
        auto trt_config = tensorrt_ptr::unique_pointer<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
        trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
        trt_config->addOptimizationProfile(*trt_profile);

        // Set INT8 Per Tensor Dynamic range
        if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
          trt_config->setInt8Calibrator(nullptr);
          if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
          }
        }

        // Set precision
        if (trt_state->fp16_enable && trt_state->int8_enable) {
          trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
        } else if (trt_state->fp16_enable) {
          trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
        } else if (trt_state->int8_enable) {
          trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
        }

        // Set DLA (DLA can only run with FP16 or INT8)
        if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
          trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
          trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
          trt_config->setDLACore(trt_state->dla_core);
        }

        // Build engine
        {
          auto lock = GetEngineBuildLock();
          *(trt_state->engine) = tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>(
              trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config));
        }
        if (*(trt_state->engine) == nullptr) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
        }
        trt_engine = trt_state->engine->get();
        if (trt_state->engine_cache_enable) {
          // Serialize engine profile
          SerializeProfile(profile_cache_path, shape_ranges);
          LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;

          // Serialize engine
          nvinfer1::IHostMemory* serializedModel = trt_engine->serialize();
          size_t engine_size = serializedModel->size();
          if (trt_state->engine_decryption_enable) {
            // Encrypt engine
            if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
              return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                     "TensorRT EP could not call engine encryption function encrypt");
            }
          } else {
            std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
            file.write(reinterpret_cast<char*>(serializedModel->data()), engine_size);
          }
          serializedModel->destroy();
        }

        // Build context
        *(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
            trt_state->engine->get()->createExecutionContext());
        if (*(trt_state->context) == nullptr) {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
        }
        trt_context = trt_state->context->get();
      }

      // Get input and output binding names
      int total_bindings = trt_engine->getNbBindings();
      std::vector<void*> buffers(total_bindings);
      std::vector<std::string> input_binding_names, output_binding_names;
      for (int i = 0, end = total_bindings; i < end; ++i) {
        if (trt_engine->bindingIsInput(i)) {
          input_binding_names.push_back(trt_engine->getBindingName(i));
        } else {
          output_binding_names.push_back(trt_engine->getBindingName(i));
        }
      }
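      // buffers is indexed by TensorRT binding index and handed to enqueueV2 below; each entry is
      // either the ORT tensor's device memory or a scratch allocation (for empty tensors and
      // INT64/DOUBLE emulation).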

      // Set input shapes and assign input buffers
      std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
      for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
        const std::string& input_name = input_binding_names[i];
        int binding_index = trt_engine->getBindingIndex(input_name.c_str());
        if (binding_index == -1) {
          continue;
        }

        size_t input_index = 0;
        const auto& iter = input_indexes.find(input_name);
        if (iter != input_indexes.end()) {
          input_index = iter->second;
        }
        const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_index);
        auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
        const auto& tensor_shapes = ort.GetTensorShape(tensor_info);

        // Set dynamic shapes
        nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast<int>(binding_index));
        int nb_dims = dimensions.nbDims;
        if (input_names.count(input_name) == 1) {
          if (trt_engine->isShapeBinding(binding_index)) {
            trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]);
          } else {
            for (int j = 0, end = nb_dims; j < end; ++j) {
              dimensions.d[j] = static_cast<int32_t>(tensor_shapes[j]);
            }
            trt_context->setBindingDimensions(binding_index, dimensions);
          }
        }

        const auto& input_type = ort.GetTensorElementType(tensor_info);
        switch (input_type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
            auto input_tensor_ptr = ort.GetTensorData<float>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = const_cast<float*>(input_tensor_ptr);
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
            auto input_tensor_ptr = ort.GetTensorData<uint16_t>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(uint16_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = const_cast<uint16_t*>(input_tensor_ptr);
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
            auto input_tensor_ptr = ort.GetTensorData<bool>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(bool)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = const_cast<bool*>(input_tensor_ptr);
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
            auto input_tensor_ptr = ort.GetTensorData<int8_t>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int8_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = const_cast<int8_t*>(input_tensor_ptr);
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
            auto input_tensor_ptr = ort.GetTensorData<int32_t>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = const_cast<int32_t*>(input_tensor_ptr);
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
            // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
            auto input_tensor_ptr = ort.GetTensorData<int64_t>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              SafeInt<int> input_dim_size = 1;
              for (int j = 0, end = nb_dims; j < end; ++j) {
                if (tensor_shapes[j] == 0) {
                  input_dim_size = 1;
                  break;
                } else {
                  input_dim_size *= tensor_shapes[j];
                }
              }
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, input_dim_size * sizeof(int32_t)));
              buffers[binding_index] = scratch_buffers.back().get();
              cuda::Impl_Cast<int64_t, int32_t>(stream, input_tensor_ptr, reinterpret_cast<int32_t*>(buffers[binding_index]), input_dim_size);
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
            // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support DOUBLE
            auto input_tensor_ptr = ort.GetTensorData<double>(input_tensor);
            if (input_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              SafeInt<int> input_dim_size = 1;
              for (int j = 0, end = nb_dims; j < end; ++j) {
                if (tensor_shapes[j] == 0) {
                  input_dim_size = 1;
                  break;
                } else {
                  input_dim_size *= tensor_shapes[j];
                }
              }
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, input_dim_size * sizeof(float)));
              buffers[binding_index] = scratch_buffers.back().get();
              cuda::Impl_Cast<double, float>(stream, input_tensor_ptr, reinterpret_cast<float*>(buffers[binding_index]), input_dim_size);
            }
            break;
          }
          default: {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                   "TensorRT EP input onnx tensor data type: " + std::to_string(input_type) + " not supported.");
          }
        }
        ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
      }

      // Set output shapes and assign output buffers
      std::vector<int> output_dim_sizes(num_outputs, 1);
      std::vector<OrtValue*> output_tensor(num_outputs, nullptr);
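      // output_dim_sizes records element counts for outputs that need a post-inference cast
      // (the INT64 and DOUBLE emulation handled after enqueueV2).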
      for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
        // Set dynamic shapes
        const std::string& output_name = output_binding_names[i];
        int binding_index = trt_engine->getBindingIndex(output_name.c_str());
        if (binding_index == -1) {
          continue;
        }

        size_t output_index = 0;
        const auto& index_iter = output_indexes.find(output_name);
        if (index_iter != output_indexes.end()) {
          output_index = index_iter->second;
        }
        nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast<int>(binding_index));
        int nb_dims = dimensions.nbDims;
        std::vector<int64_t> output_shapes(nb_dims);
        for (int j = 0, end = nb_dims; j < end; ++j) {
          output_shapes[j] = dimensions.d[j];
        }
        output_tensor[i] = ort.KernelContext_GetOutput(context, output_index, output_shapes.data(), output_shapes.size());

        size_t output_type = 0;
        const auto& type_iter = output_types.find(output_name);
        if (type_iter != output_types.end()) {
          output_type = type_iter->second;
        }

        switch (output_type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
            auto output_tensor_ptr = ort.GetTensorMutableData<float>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = output_tensor_ptr;
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
            auto output_tensor_ptr = ort.GetTensorMutableData<uint16_t>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(uint16_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = output_tensor_ptr;
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
            auto output_tensor_ptr = ort.GetTensorMutableData<bool>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(bool)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = output_tensor_ptr;
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
            auto output_tensor_ptr = ort.GetTensorMutableData<int8_t>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int8_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = output_tensor_ptr;
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
            auto output_tensor_ptr = ort.GetTensorMutableData<int32_t>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              buffers[binding_index] = output_tensor_ptr;
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
            // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64
            auto output_tensor_ptr = ort.GetTensorMutableData<int64_t>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(int32_t)));
              buffers[binding_index] = scratch_buffers.back().get();
              output_dim_sizes[i] = 1;
            } else {
              SafeInt<int> output_dim_size(output_dim_sizes[i]);
              for (int j = 0, end = nb_dims; j < end; ++j) {
                if (dimensions.d[j] == 0) {
                  output_dim_size = 1;
                  break;
                } else {
                  output_dim_size *= dimensions.d[j];
                }
              }
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, output_dim_size * sizeof(int32_t)));
              buffers[binding_index] = scratch_buffers.back().get();
              output_dim_sizes[i] = output_dim_size;
            }
            break;
          }
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
            // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE
            auto output_tensor_ptr = ort.GetTensorMutableData<double>(output_tensor[i]);
            if (output_tensor_ptr == nullptr) {
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, sizeof(float)));
              buffers[binding_index] = scratch_buffers.back().get();
            } else {
              SafeInt<int> output_dim_size(output_dim_sizes[i]);
              for (int j = 0, end = nb_dims; j < end; ++j) {
                if (dimensions.d[j] == 0) {
                  output_dim_size = 1;
                  break;
                } else {
                  output_dim_size *= dimensions.d[j];
                }
              }
              scratch_buffers.push_back(IAllocator::MakeUniquePtr<void>(alloc, output_dim_size * sizeof(float)));
              buffers[binding_index] = scratch_buffers.back().get();
              output_dim_sizes[i] = output_dim_size;
            }
            break;
          }
          default: {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                   "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.");
          }
        }
      }

      // Run TRT inference
      if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
      }

      // Cast INT32 output data back to INT64 and FLOAT back to DOUBLE, since TensorRT doesn't fully support INT64 or DOUBLE
      for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
        const std::string& output_name = output_binding_names[i];
        size_t binding_index = trt_engine->getBindingIndex(output_name.c_str());
        size_t output_type = 0;
        const auto& iter = output_types.find(output_name);
        if (iter != output_types.end()) {
          output_type = iter->second;
        }
        if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
          auto output_tensor_ptr = ort.GetTensorMutableData<int64_t>(output_tensor[i]);
          if (output_tensor_ptr != nullptr) {
            cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]);
          }
        } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
          auto output_tensor_ptr = ort.GetTensorMutableData<double>(output_tensor[i]);
          if (output_tensor_ptr != nullptr) {
            cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[binding_index]), output_tensor_ptr, output_dim_sizes[i]);
          }
        }
      }
      return Status::OK();
    };

    node_compute_funcs.push_back(compute_info);
  }
  return Status::OK();
}