void TreeliteModel::SetupTreeliteModule()

in src/dlr_treelite.cc [48:99]


void TreeliteModel::SetupTreeliteModule(const std::vector<std::string>& model_path) {
  ModelPath paths = SetTreelitePaths(model_path);
  // If OMP_NUM_THREADS is set, use it to determine number of threads;
  // if not, use the maximum amount of threads
  const char* val = std::getenv("OMP_NUM_THREADS");
  int num_worker_threads = (val ? std::atoi(val) : -1);
  num_inputs_ = 1;
  num_outputs_ = 1;
  // Give a dummy input name to Treelite model.
  input_names_.push_back(INPUT_NAME);
  input_types_.push_back(INPUT_TYPE);
  CHECK_EQ(TreelitePredictorLoad(paths.model_lib.c_str(), num_worker_threads, &treelite_model_), 0)
      << TreeliteGetLastError();
  CHECK_EQ(TreelitePredictorQueryNumFeature(treelite_model_, &treelite_num_feature_), 0)
      << TreeliteGetLastError();
  treelite_input_.reset(nullptr);

  const char* output_type;
  CHECK_EQ(TreelitePredictorQueryLeafOutputType(treelite_model_, &output_type), 0)
      << TreeliteGetLastError();
  CHECK_EQ(std::string(output_type), "float32")
      << "Only float32 output types are supported, got " << output_type;
  size_t num_output_class;  // > 1 for multi-class classification; 1 otherwise
  CHECK_EQ(TreelitePredictorQueryNumClass(treelite_model_, &num_output_class), 0)
      << TreeliteGetLastError();
  treelite_output_buffer_size_ = num_output_class;
  treelite_output_.empty();

  // NOTE: second dimension of the output shape is smaller than num_output_class
  //       when a multi-class classifier outputs only the class prediction
  //       (argmax) To detect this edge case, run TreelitePredictorQueryResultSize()
  DMatrixHandle tmp_matrix;
  std::vector<float> tmp_in(treelite_num_feature_);
  const float missing_value = 0.0f;
  CHECK_EQ(TreeliteDMatrixCreateFromMat(tmp_in.data(), "float32", /*num_row=*/1,
                                        treelite_num_feature_, &missing_value, &tmp_matrix),
           0)
      << TreeliteGetLastError();
  CHECK_EQ(TreelitePredictorQueryResultSize(treelite_model_, tmp_matrix, &treelite_output_size_), 0)
      << TreeliteGetLastError();
  CHECK_LE(treelite_output_size_, num_output_class) << "Precondition violated";

  UpdateInputShapes();
  has_sparse_input_ = false;
  if (!paths.metadata.empty() && !IsFileEmpty(paths.metadata)) {
    LoadJsonFromFile(paths.metadata, this->metadata_);
    ValidateDeviceTypeIfExists();
    if (metadata_.count("Model") && metadata_["Model"].count("SparseInput")) {
      has_sparse_input_ = metadata_["Model"]["SparseInput"].get<std::string>() == "1";
    }
  }
}