inline treelite::Model ParseStream()

in src/frontend/xgboost.cc [320:443]


/*!
 * \brief Parse a model saved in XGBoost's legacy binary format and convert it
 *        into a treelite::Model.
 *
 * The stream is consumed in this order: an optional 4-byte header ("binf" is
 * skipped, "bs64" is rejected), the learner parameters, the objective name,
 * the gradient booster name, the GBTree parameters, and then
 * gbm_param_.num_trees trees.
 *
 * \param fi input stream positioned at the start of the serialized model
 * \return the converted model
 *
 * Triggers a CHECK failure if the stream is truncated or corrupted, if the
 * gradient booster is not "gbtree", or if the trees have more than one root.
 */
inline treelite::Model ParseStream(dmlc::Stream* fi) {
  std::vector<XGBTree> xgb_trees_;
  LearnerModelParam mparam_;    // model parameter
  GBTreeModelParam gbm_param_;  // GBTree training parameter
  std::string name_gbm_;
  std::string name_obj_;

  /* 1. Parse input stream */
  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
  // Backward-compatible header check: the 4-byte header is optional, so peek
  // at it first and consume it only when it is the "binf" marker.
  std::string header;
  header.resize(4);
  if (fp->PeekRead(&header[0], 4) == 4) {
    CHECK_NE(header, "bs64")
        << "Ill-formed XGBoost model file: Base64 format no longer supported";
    if (header == "binf") {
      CONSUME_BYTES(fp, 4);
    }
  }
  // read learner parameters (raw struct)
  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
      << "Ill-formed XGBoost model file: corrupted header";
  {
    // Objective name.
    // Backward-compatibility code for older model files: a length that does
    // not fit below numeric_limits<unsigned>::max() signals the old layout,
    // in which an extra int-sized gap follows and the real length is held in
    // the upper 32 bits. For new models, reading the length-prefixed string
    // directly suffices.
    uint64_t len;
    CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
        << "Ill-formed XGBoost model file: corrupted header";
    if (len >= std::numeric_limits<unsigned>::max()) {
      int gap;
      CHECK_EQ(fp->Read(&gap, sizeof(gap)), sizeof(gap))
          << "Ill-formed XGBoost model file: corrupted header";
      len = len >> static_cast<uint64_t>(32UL);
    }
    if (len != 0) {
      name_obj_.resize(len);
      CHECK_EQ(fp->Read(&name_obj_[0], len), len)
          << "Ill-formed XGBoost model file: corrupted header";
    }
  }

  {
    // Gradient booster name (length-prefixed string).
    uint64_t len;
    CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
        << "Ill-formed XGBoost model file: corrupted header";
    name_gbm_.resize(len);
    if (len > 0) {
      CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
          << "Ill-formed XGBoost model file: corrupted header";
    }
  }

  /* loading GBTree */
  CHECK_EQ(name_gbm_, "gbtree")
      << "Invalid XGBoost model file: "
      << "Gradient booster must be gbtree type.";

  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
      << "Invalid XGBoost model file: corrupted GBTree parameters";
  // Reject unsupported multi-root models before parsing any of their trees.
  CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
  for (int i = 0; i < gbm_param_.num_trees; ++i) {
    xgb_trees_.emplace_back();
    xgb_trees_.back().Load(fp.get());
  }

  /* 2. Export model */
  treelite::Model model;
  model.num_feature = gbm_param_.num_feature;
  model.num_output_group = gbm_param_.num_output_group;
  model.random_forest_flag = false;  // gradient-boosted ensemble

  // set global bias
  // NOTE(review): base_score is copied verbatim, which assumes the stored
  // value is already in margin space for the XGBoost versions being loaded;
  // confirm.
  model.param.global_bias = static_cast<float>(mparam_.base_score);

  // Pick the prediction transform that mirrors the transform XGBoost applies
  // for this objective, so treelite outputs match XGBoost's predictions.
  if (name_obj_ == "multi:softmax") {
    model.param.pred_transform = "max_index";
  } else if (name_obj_ == "multi:softprob") {
    model.param.pred_transform = "softmax";
  } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
    model.param.pred_transform = "sigmoid";
    model.param.sigmoid_alpha = 1.0f;
  } else if (name_obj_ == "count:poisson" || name_obj_ == "reg:gamma"
             || name_obj_ == "reg:tweedie" || name_obj_ == "survival:cox"
             || name_obj_ == "survival:aft") {
    // These objectives fit a log-scale margin. For example, Cox regression
    // reports the hazard ratio exp(margin), and AFT reports the survival time
    // exp(margin). XGBoost therefore exponentiates the raw score.
    model.param.pred_transform = "exponential";
  } else {
    // Any other objective (e.g. binary:logitraw) keeps the raw margin.
    model.param.pred_transform = "identity";
  }

  // traverse trees
  for (const auto& xgb_tree : xgb_trees_) {
    model.trees.emplace_back();
    treelite::Tree& tree = model.trees.back();
    tree.Init();

    // Assign node IDs so that a breadth-first traversal yields the monotonic
    // sequence 0, 1, 2, ... The traversal starts at node 0 and follows only
    // child links, so deleted nodes are excluded.
    std::queue<std::pair<int, int>> Q;  // (old ID, new ID) pair
    Q.push({0, 0});
    while (!Q.empty()) {
      int old_id, new_id;
      std::tie(old_id, new_id) = Q.front();
      Q.pop();
      const XGBTree::Node& node = xgb_tree[old_id];
      const NodeStat stat = xgb_tree.Stat(old_id);
      if (node.is_leaf()) {
        const bst_float leaf_value = node.leaf_value();
        tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
      } else {
        const bst_float split_cond = node.split_cond();
        // AddChilds() creates the two children; their new IDs are read back
        // through cleft()/cright() below.
        tree.AddChilds(new_id);
        tree[new_id].set_numerical_split(node.split_index(),
                                   static_cast<treelite::tl_float>(split_cond),
                                   node.default_left(),
                                   treelite::Operator::kLT);
        tree[new_id].set_gain(stat.loss_chg);
        Q.push({node.cleft(), tree[new_id].cleft()});
        Q.push({node.cright(), tree[new_id].cright()});
      }
      tree[new_id].set_sum_hess(stat.sum_hess);
    }
  }
  return model;
}