void ExportProtobufModel()

in src/frontend/protobuf.cc [218:324]


// Serialize `model` into a treelite_protobuf::Model message and write it to
// the stream opened for `filename`.
//
// Each tree is walked breadth-first from node 0; every visited node is copied
// into the matching (lazily created) protobuf node. Leaves are emitted either
// with a scalar value or with a leaf vector, and the two styles must not be
// mixed anywhere in the model: the consistency flag below is shared by all
// trees, and a violation triggers a CHECK failure. A CHECK failure is also
// raised when a leaf vector's length differs from model.num_output_group or
// when protobuf serialization reports failure.
//
// filename: destination handed to dmlc::Stream::Create in "w" mode.
// model:    the model to export; it is only read.
void ExportProtobufModel(const char* filename, const Model& model) {
  // Protobuf-provided macro; checks that the linked library matches the
  // version of the headers the generated code was compiled against.
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "w"));
  dmlc::ostream os(fi.get());
  treelite_protobuf::Model protomodel;

  protomodel.set_num_feature(
    static_cast<google::protobuf::int32>(model.num_feature));

  protomodel.set_num_output_group(
    static_cast<google::protobuf::int32>(model.num_output_group));

  protomodel.set_random_forest_flag(model.random_forest_flag);

  // extra parameters field: copy every key/value pair of the model parameters
  for (const auto& kv : model.param.__DICT__()) {
    (*protomodel.mutable_extra_params())[kv.first] = kv.second;
  }

  // flag to check consistent use of leaf vector
  // 0: no leaf should use leaf vector
  // 1: every leaf should use leaf vector
  // -1: indeterminate (no leaf has been visited yet)
  int8_t flag_leaf_vector = -1;

  // size_t counter: avoids narrowing the container size into an int
  const size_t ntree = model.trees.size();
  for (size_t i = 0; i < ntree; ++i) {
    const Tree& tree = model.trees[i];
    treelite_protobuf::Tree* proto_tree = protomodel.add_trees();

    // Breadth-first traversal. Each queue entry pairs a node ID in `tree`
    // with the protobuf node that should receive that node's content.
    std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
    Q.push({0, proto_tree->mutable_head()});
    while (!Q.empty()) {
      auto elem = Q.front(); Q.pop();
      const int nid = elem.first;
      treelite_protobuf::Node* proto_node = elem.second;
      if (tree[nid].is_leaf()) {  // leaf node
        if (tree[nid].has_leaf_vector()) {  // leaf node with vector output
          CHECK(flag_leaf_vector != 0)
            << "Inconsistent use of leaf vector: if one leaf node uses "
            << "a leaf vector, *every* leaf node must use a leaf vector as well";
          flag_leaf_vector = 1;  // now every leaf must use leaf vector

          const auto& leaf_vector = tree[nid].leaf_vector();
          CHECK_EQ(leaf_vector.size(), model.num_output_group)
            << "The length of leaf vector must be identical to the "
            << "number of output groups";
          for (tl_float e : leaf_vector) {
            proto_node->add_leaf_vector(static_cast<double>(e));
          }
          // sanity check: every element made it into the protobuf node
          CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
        } else {  // leaf node with scalar output
          // Note the trailing space after "use": the two literals are
          // streamed back to back and would otherwise run together.
          CHECK(flag_leaf_vector != 1)
            << "Inconsistent use of leaf vector: if one leaf node does not use "
            << "a leaf vector, *no other* leaf node can use a leaf vector";
          flag_leaf_vector = 0;  // now no leaf can use leaf vector

          proto_node->set_leaf_value(static_cast<double>(tree[nid].leaf_value()));
        }
      } else if (tree[nid].split_type() == SplitFeatureType::kNumerical) {
        // numerical split
        const unsigned split_index = tree[nid].split_index();
        const tl_float threshold = tree[nid].threshold();
        const bool default_left = tree[nid].default_left();
        const Operator op = tree[nid].comparison_op();

        proto_node->set_default_left(default_left);
        proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
        proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
        proto_node->set_op(OpName(op));
        proto_node->set_threshold(static_cast<double>(threshold));
        // schedule both children, left first
        Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
        Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
      } else {  // categorical split (any non-leaf that is not numerical)
        const unsigned split_index = tree[nid].split_index();
        const auto& left_categories = tree[nid].left_categories();
        const bool default_left = tree[nid].default_left();
        const bool missing_category_to_zero = tree[nid].missing_category_to_zero();

        proto_node->set_default_left(default_left);
        proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
        proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
        proto_node->set_missing_category_to_zero(missing_category_to_zero);
        for (auto e : left_categories) {
          proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
        }
        // schedule both children, left first
        Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
        Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
      }
      /* set node statistics; each one is written only if the node has it */
      if (tree[nid].has_data_count()) {
        proto_node->set_data_count(
          static_cast<google::protobuf::uint64>(tree[nid].data_count()));
      }
      if (tree[nid].has_sum_hess()) {
        proto_node->set_sum_hess(tree[nid].sum_hess());
      }
      if (tree[nid].has_gain()) {
        proto_node->set_gain(tree[nid].gain());
      }
    }
  }
  CHECK(protomodel.SerializeToOstream(&os))
    << "Failed to write Protocol Buffers file";
  // NOTE(review): detaching the stream presumably flushes buffered output
  // while `fi` is still alive -- confirm against dmlc::ostream::set_stream.
  os.set_stream(nullptr);
}