in src/dlr_treelite.cc [48:99]
void TreeliteModel::SetupTreeliteModule(const std::vector<std::string>& model_path) {
ModelPath paths = SetTreelitePaths(model_path);
// If OMP_NUM_THREADS is set, use it to determine number of threads;
// if not, use the maximum amount of threads
const char* val = std::getenv("OMP_NUM_THREADS");
int num_worker_threads = (val ? std::atoi(val) : -1);
num_inputs_ = 1;
num_outputs_ = 1;
// Give a dummy input name to Treelite model.
input_names_.push_back(INPUT_NAME);
input_types_.push_back(INPUT_TYPE);
CHECK_EQ(TreelitePredictorLoad(paths.model_lib.c_str(), num_worker_threads, &treelite_model_), 0)
<< TreeliteGetLastError();
CHECK_EQ(TreelitePredictorQueryNumFeature(treelite_model_, &treelite_num_feature_), 0)
<< TreeliteGetLastError();
treelite_input_.reset(nullptr);
const char* output_type;
CHECK_EQ(TreelitePredictorQueryLeafOutputType(treelite_model_, &output_type), 0)
<< TreeliteGetLastError();
CHECK_EQ(std::string(output_type), "float32")
<< "Only float32 output types are supported, got " << output_type;
size_t num_output_class; // > 1 for multi-class classification; 1 otherwise
CHECK_EQ(TreelitePredictorQueryNumClass(treelite_model_, &num_output_class), 0)
<< TreeliteGetLastError();
treelite_output_buffer_size_ = num_output_class;
treelite_output_.empty();
// NOTE: second dimension of the output shape is smaller than num_output_class
// when a multi-class classifier outputs only the class prediction
// (argmax) To detect this edge case, run TreelitePredictorQueryResultSize()
DMatrixHandle tmp_matrix;
std::vector<float> tmp_in(treelite_num_feature_);
const float missing_value = 0.0f;
CHECK_EQ(TreeliteDMatrixCreateFromMat(tmp_in.data(), "float32", /*num_row=*/1,
treelite_num_feature_, &missing_value, &tmp_matrix),
0)
<< TreeliteGetLastError();
CHECK_EQ(TreelitePredictorQueryResultSize(treelite_model_, tmp_matrix, &treelite_output_size_), 0)
<< TreeliteGetLastError();
CHECK_LE(treelite_output_size_, num_output_class) << "Precondition violated";
UpdateInputShapes();
has_sparse_input_ = false;
if (!paths.metadata.empty() && !IsFileEmpty(paths.metadata)) {
LoadJsonFromFile(paths.metadata, this->metadata_);
ValidateDeviceTypeIfExists();
if (metadata_.count("Model") && metadata_["Model"].count("SparseInput")) {
has_sparse_input_ = metadata_["Model"]["SparseInput"].get<std::string>() == "1";
}
}
}