in src/frontend/xgboost.cc [320:443]
/*!
 * \brief Read a binary XGBoost model from a stream and convert it to a
 *        treelite::Model.
 *
 * The stream is consumed in this order: optional 4-byte "binf" magic, learner
 * parameters, objective name, gradient booster name, GBTree parameters, and
 * finally gbm_param_.num_trees serialized trees. Any short read or
 * unsupported setting aborts through the CHECK_* macros.
 *
 * \param fi input stream to read the model from, starting at its current
 *           position
 * \return model whose trees contain only the nodes reachable from each root,
 *         renumbered in breadth-first order
 */
inline treelite::Model ParseStream(dmlc::Stream* fi) {
  std::vector<XGBTree> xgb_trees_;
  LearnerModelParam mparam_;    // model parameter
  GBTreeModelParam gbm_param_;  // GBTree training parameter
  std::string name_gbm_;
  std::string name_obj_;

  /* 1. Parse input stream */
  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));

  // Backward-compatible header check: peek at the first four bytes without
  // consuming them, reject the legacy Base64 encoding, and skip the "binf"
  // magic if it is present.
  std::string header;
  header.resize(4);
  if (fp->PeekRead(&header[0], 4) == 4) {
    CHECK_NE(header, "bs64")
      << "Ill-formed XGBoost model file: Base64 format no longer supported";
    if (header == "binf") {
      CONSUME_BYTES(fp, 4);
    }
  }

  // Read learner parameters
  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
    << "Ill-formed XGBoost model file: corrupted header";

  // Read the name of the objective function (length-prefixed string)
  {
    // Backward-compatibility path for old model files; for new models the
    // plain 64-bit length prefix is sufficient.
    uint64_t len;
    CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
      << "Ill-formed XGBoost model file: corrupted header";
    if (len >= std::numeric_limits<unsigned>::max()) {
      // A length this large is handled as the old layout: read one more int
      // and take the real length from the upper 32 bits of the value.
      int gap;
      CHECK_EQ(fp->Read(&gap, sizeof(gap)), sizeof(gap))
        << "Ill-formed XGBoost model file: corrupted header";
      len = len >> static_cast<uint64_t>(32UL);
    }
    if (len != 0) {
      name_obj_.resize(len);
      CHECK_EQ(fp->Read(&name_obj_[0], len), len)
        << "Ill-formed XGBoost model file: corrupted header";
    }
  }

  // Read the name of the gradient booster (length-prefixed string)
  {
    uint64_t len;
    CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
      << "Ill-formed XGBoost model file: corrupted header";
    name_gbm_.resize(len);
    if (len > 0) {
      CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
        << "Ill-formed XGBoost model file: corrupted header";
    }
  }

  /* loading GBTree */
  CHECK_EQ(name_gbm_, "gbtree")
    << "Invalid XGBoost model file: "
    << "Gradient booster must be gbtree type.";
  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
    << "Invalid XGBoost model file: corrupted GBTree parameters";
  // Validate the parameters before touching the tree data, so a corrupted or
  // unsupported file is rejected up front instead of silently producing an
  // empty model (negative tree count) or being parsed needlessly.
  CHECK_GE(gbm_param_.num_trees, 0)
    << "Invalid XGBoost model file: num_trees must be 0 or greater";
  CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
  for (int i = 0; i < gbm_param_.num_trees; ++i) {
    xgb_trees_.emplace_back();
    xgb_trees_.back().Load(fp.get());
  }

  /* 2. Export model */
  treelite::Model model;
  model.num_feature = gbm_param_.num_feature;
  model.num_output_group = gbm_param_.num_output_group;
  model.random_forest_flag = false;

  // set global bias
  model.param.global_bias = static_cast<float>(mparam_.base_score);

  // set correct prediction transform function, depending on objective function
  if (name_obj_ == "multi:softmax") {
    model.param.pred_transform = "max_index";
  } else if (name_obj_ == "multi:softprob") {
    model.param.pred_transform = "softmax";
  } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
    model.param.pred_transform = "sigmoid";
    model.param.sigmoid_alpha = 1.0f;
  } else if (name_obj_ == "count:poisson" || name_obj_ == "reg:gamma"
             || name_obj_ == "reg:tweedie") {
    model.param.pred_transform = "exponential";
  } else {
    // every other objective is exported without an output transform
    model.param.pred_transform = "identity";
  }

  // traverse trees
  for (const auto& xgb_tree : xgb_trees_) {
    model.trees.emplace_back();
    treelite::Tree& tree = model.trees.back();
    tree.Init();

    // assign node ID's so that a breadth-wise traversal would yield
    // the monotonic sequence 0, 1, 2, ...
    // deleted nodes will be excluded, since only nodes reachable from the
    // root are ever visited
    std::queue<std::pair<int, int>> Q;  // (old ID, new ID) pair
    Q.push({0, 0});
    while (!Q.empty()) {
      int old_id, new_id;
      std::tie(old_id, new_id) = Q.front();
      Q.pop();
      const XGBTree::Node& node = xgb_tree[old_id];
      // bind by const reference to avoid copying the statistics per node
      const NodeStat& stat = xgb_tree.Stat(old_id);
      if (node.is_leaf()) {
        const bst_float leaf_value = node.leaf_value();
        tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
      } else {
        const bst_float split_cond = node.split_cond();
        // allocate both children first, so that their new IDs exist before
        // they are queued below
        tree.AddChilds(new_id);
        tree[new_id].set_numerical_split(node.split_index(),
                                 static_cast<treelite::tl_float>(split_cond),
                                 node.default_left(),
                                 treelite::Operator::kLT);
        tree[new_id].set_gain(stat.loss_chg);
        Q.push({node.cleft(), tree[new_id].cleft()});
        Q.push({node.cright(), tree[new_id].cright()});
      }
      tree[new_id].set_sum_hess(stat.sum_hess);
    }
  }
  return model;
}