in src/frontend/protobuf.cc [70:216]
Model LoadProtobufModel(const char* filename) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
dmlc::istream is(fi.get());
treelite_protobuf::Model protomodel;
CHECK(protomodel.ParseFromIstream(&is)) << "Ill-formed Protocol Buffers file";
Model model;
CHECK(protomodel.has_num_feature()) << "num_feature must exist";
const auto num_feature = protomodel.num_feature();
CHECK_LT(num_feature, std::numeric_limits<int>::max())
<< "num_feature too big";
CHECK_GT(num_feature, 0) << "num_feature must be positive";
model.num_feature = static_cast<int>(protomodel.num_feature());
CHECK(protomodel.has_num_output_group()) << "num_output_group must exist";
const auto num_output_group = protomodel.num_output_group();
CHECK_LT(num_output_group, std::numeric_limits<int>::max())
<< "num_output_group too big";
CHECK_GT(num_output_group, 0) << "num_output_group must be positive";
model.num_output_group = static_cast<int>(protomodel.num_output_group());
CHECK(protomodel.has_random_forest_flag())
<< "random_forest_flag must exist";
model.random_forest_flag = protomodel.random_forest_flag();
// extra parameters field
const auto& ep = protomodel.extra_params();
std::vector<std::pair<std::string, std::string>> cfg;
std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
InitParamAndCheck(&model.param, cfg);
// flag to check consistent use of leaf vector
// 0: no leaf should use leaf vector
// 1: every leaf should use leaf vector
// -1: indeterminate
int8_t flag_leaf_vector = -1;
const int ntree = protomodel.trees_size();
for (int i = 0; i < ntree; ++i) {
model.trees.emplace_back();
Tree& tree = model.trees.back();
tree.Init();
CHECK(protomodel.trees(i).has_head());
// assign node ID's so that a breadth-wise traversal would yield
// the monotonic sequence 0, 1, 2, ...
std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
// (proto node, ID)
Q.push({protomodel.trees(i).head(), 0});
while (!Q.empty()) {
auto elem = Q.front(); Q.pop();
const treelite_protobuf::Node& node = elem.first;
int id = elem.second;
const NodeType node_type = GetNodeType(node);
if (node_type == NodeType::kLeaf) { // leaf node with a scalar output
CHECK(flag_leaf_vector != 1)
<< "Inconsistent use of leaf vector: if one leaf node does not use"
<< "a leaf vector, *no other* leaf node can use a leaf vector";
flag_leaf_vector = 0; // now no leaf can use leaf vector
tree[id].set_leaf(static_cast<tl_float>(node.leaf_value()));
} else if (node_type == NodeType::kLeafVector) {
// leaf node with vector output
CHECK(flag_leaf_vector != 0)
<< "Inconsistent use of leaf vector: if one leaf node uses "
<< "a leaf vector, *every* leaf node must use a leaf vector as well";
flag_leaf_vector = 1; // now every leaf must use leaf vector
const int len = node.leaf_vector_size();
CHECK_EQ(len, model.num_output_group)
<< "The length of leaf vector must be identical to the "
<< "number of output groups";
std::vector<tl_float> leaf_vector(len);
for (int i = 0; i < len; ++i) {
leaf_vector[i] = static_cast<tl_float>(node.leaf_vector(i));
}
tree[id].set_leaf_vector(leaf_vector);
} else if (node_type == NodeType::kNumericalSplit) { // numerical split
const auto split_index = node.split_index();
const std::string opname = node.op();
CHECK_LT(split_index, model.num_feature)
<< "split_index must be between 0 and [num_feature] - 1.";
CHECK_GE(split_index, 0) << "split_index must be positive.";
CHECK_GT(optable.count(opname), 0) << "No operator `"
<< opname << "\" exists";
tree.AddChilds(id);
tree[id].set_numerical_split(static_cast<unsigned>(split_index),
static_cast<tl_float>(node.threshold()),
node.default_left(),
optable.at(opname.c_str()));
Q.push({node.left_child(), tree[id].cleft()});
Q.push({node.right_child(), tree[id].cright()});
} else { // categorical split
const auto split_index = node.split_index();
CHECK_LT(split_index, model.num_feature)
<< "split_index must be between 0 and [num_feature] - 1.";
CHECK_GE(split_index, 0) << "split_index must be positive.";
const int left_categories_size = node.left_categories_size();
std::vector<uint32_t> left_categories;
for (int i = 0; i < left_categories_size; ++i) {
const auto cat = node.left_categories(i);
CHECK(cat <= std::numeric_limits<uint32_t>::max());
left_categories.push_back(static_cast<uint32_t>(cat));
}
tree.AddChilds(id);
tree[id].set_categorical_split(static_cast<unsigned>(split_index),
node.default_left(),
node.missing_category_to_zero(),
left_categories);
Q.push({node.left_child(), tree[id].cleft()});
Q.push({node.right_child(), tree[id].cright()});
}
/* set node statistics */
if (node.has_data_count()) {
tree[id].set_data_count(static_cast<size_t>(node.data_count()));
}
if (node.has_sum_hess()) {
tree[id].set_sum_hess(node.sum_hess());
}
if (node.has_gain()) {
tree[id].set_gain(node.gain());
}
}
}
if (flag_leaf_vector == 0) {
if (model.num_output_group > 1) {
// multiclass classification with gradient boosted trees
CHECK(!model.random_forest_flag)
<< "To use a random forest for multi-class classification, each leaf "
<< "node must output a leaf vector specifying a probability "
<< "distribution";
CHECK_EQ(ntree % model.num_output_group, 0)
<< "For multi-class classifiers with gradient boosted trees, the number "
<< "of trees must be evenly divisible by the number of output groups";
}
} else if (flag_leaf_vector == 1) {
// multiclass classification with a random forest
CHECK(model.random_forest_flag)
<< "In multi-class classifiers with gradient boosted trees, each leaf "
<< "node must output a single floating-point value.";
} else {
LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
}
return model;
}