in src/frontend/protobuf.cc [218:324]
void ExportProtobufModel(const char* filename, const Model& model) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "w"));
dmlc::ostream os(fi.get());
treelite_protobuf::Model protomodel;
protomodel.set_num_feature(
static_cast<google::protobuf::int32>(model.num_feature));
protomodel.set_num_output_group(
static_cast<google::protobuf::int32>(model.num_output_group));
protomodel.set_random_forest_flag(model.random_forest_flag);
// extra parameters field
for (const auto& kv : model.param.__DICT__()) {
(*protomodel.mutable_extra_params())[kv.first] = kv.second;
}
// flag to check consistent use of leaf vector
// 0: no leaf should use leaf vector
// 1: every leaf should use leaf vector
// -1: indeterminate
int8_t flag_leaf_vector = -1;
const int ntree = model.trees.size();
for (int i = 0; i < ntree; ++i) {
const Tree& tree = model.trees[i];
treelite_protobuf::Tree* proto_tree = protomodel.add_trees();
std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
Q.push({0, proto_tree->mutable_head()});
while (!Q.empty()) {
auto elem = Q.front(); Q.pop();
const int nid = elem.first;
treelite_protobuf::Node* proto_node = elem.second;
if (tree[nid].is_leaf()) { // leaf node
if (tree[nid].has_leaf_vector()) { // leaf node with vector output
CHECK(flag_leaf_vector != 0)
<< "Inconsistent use of leaf vector: if one leaf node uses "
<< "a leaf vector, *every* leaf node must use a leaf vector as well";
flag_leaf_vector = 1; // now every leaf must use leaf vector
const auto& leaf_vector = tree[nid].leaf_vector();
CHECK_EQ(leaf_vector.size(), model.num_output_group)
<< "The length of leaf vector must be identical to the "
<< "number of output groups";
for (tl_float e : leaf_vector) {
proto_node->add_leaf_vector(static_cast<double>(e));
}
CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
} else { // leaf node with scalar output
CHECK(flag_leaf_vector != 1)
<< "Inconsistent use of leaf vector: if one leaf node does not use"
<< "a leaf vector, *no other* leaf node can use a leaf vector";
flag_leaf_vector = 0; // now no leaf can use leaf vector
proto_node->set_leaf_value(static_cast<double>(tree[nid].leaf_value()));
}
} else if (tree[nid].split_type() == SplitFeatureType::kNumerical) {
// numerical split
const unsigned split_index = tree[nid].split_index();
const tl_float threshold = tree[nid].threshold();
const bool default_left = tree[nid].default_left();
const Operator op = tree[nid].comparison_op();
proto_node->set_default_left(default_left);
proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
proto_node->set_op(OpName(op));
proto_node->set_threshold(static_cast<double>(threshold));
Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
} else { // categorical split
const unsigned split_index = tree[nid].split_index();
const auto& left_categories = tree[nid].left_categories();
const bool default_left = tree[nid].default_left();
const bool missing_category_to_zero = tree[nid].missing_category_to_zero();
proto_node->set_default_left(default_left);
proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
proto_node->set_missing_category_to_zero(missing_category_to_zero);
for (auto e : left_categories) {
proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
}
Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
}
/* set node statistics */
if (tree[nid].has_data_count()) {
proto_node->set_data_count(
static_cast<google::protobuf::uint64>(tree[nid].data_count()));
}
if (tree[nid].has_sum_hess()) {
proto_node->set_sum_hess(tree[nid].sum_hess());
}
if (tree[nid].has_gain()) {
proto_node->set_gain(tree[nid].gain());
}
}
}
CHECK(protomodel.SerializeToOstream(&os))
<< "Failed to write Protocol Buffers file";
os.set_stream(nullptr);
}