inline treelite::Model ParseStream()

in src/frontend/lightgbm.cc [126:418]


inline treelite::Model ParseStream(dmlc::Stream* fi) {
  std::vector<LGBTree> lgb_trees_;
  int max_feature_idx_;
  int num_tree_per_iteration_;
  bool average_output_;
  std::string obj_name_;
  std::vector<std::string> obj_param_;

  /* 1. Parse input stream */
  std::vector<std::string> lines = LoadText(fi);
  std::unordered_map<std::string, std::string> global_dict;
  std::vector<std::unordered_map<std::string, std::string>> tree_dict;

  bool in_tree = false;  // is current entry part of a tree?
  for (const auto& line : lines) {
    std::istringstream ss(line);
    std::string key, value, rest;
    std::getline(ss, key, '=');
    std::getline(ss, value, '=');
    std::getline(ss, rest);
    CHECK(rest.empty()) << "Ill-formed LightGBM model file";
    if (key == "Tree") {
      in_tree = true;
      tree_dict.emplace_back();
    } else {
      if (in_tree) {
        tree_dict.back()[key] = value;
      } else {
        global_dict[key] = value;
      }
    }
  }

  {
    auto it = global_dict.find("objective");
    CHECK(it != global_dict.end())
      << "Ill-formed LightGBM model file: need objective";
    auto obj_strs = treelite::common::Split(it->second, ' ');
    obj_name_ = obj_strs[0];
    obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());

    it = global_dict.find("max_feature_idx");
    CHECK(it != global_dict.end())
      << "Ill-formed LightGBM model file: need max_feature_idx";
    max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
    it = global_dict.find("num_tree_per_iteration");
    CHECK(it != global_dict.end())
      << "Ill-formed LightGBM model file: need num_tree_per_iteration";
    num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);

    it = global_dict.find("average_output");
    average_output_ = (it != global_dict.end());
  }

  for (const auto& dict : tree_dict) {
    lgb_trees_.emplace_back();
    LGBTree& tree = lgb_trees_.back();

    auto it = dict.find("num_leaves");
    CHECK(it != dict.end())
      << "Ill-formed LightGBM model file: need num_leaves";
    tree.num_leaves = treelite::common::TextToNumber<int>(it->second);

    it = dict.find("num_cat");
    CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
    tree.num_cat = treelite::common::TextToNumber<int>(it->second);

    it = dict.find("leaf_value");
    CHECK(it != dict.end() && !it->second.empty())
      << "Ill-formed LightGBM model file: need leaf_value";
    tree.leaf_value
      = treelite::common::TextToArray<double>(it->second, tree.num_leaves);

    it = dict.find("decision_type");
    if (it == dict.end()) {
      tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
    } else {
      CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
        << "Ill-formed LightGBM model file: decision_type cannot be empty string";
      tree.decision_type
        = treelite::common::TextToArray<int8_t>(it->second,
                                                tree.num_leaves - 1);
    }

    if (tree.num_cat > 0) {
      it = dict.find("cat_boundaries");
      CHECK(it != dict.end() && !it->second.empty())
        << "Ill-formed LightGBM model file: need cat_boundaries";
      tree.cat_boundaries
        = treelite::common::TextToArray<uint64_t>(it->second, tree.num_cat + 1);
      it = dict.find("cat_threshold");
      CHECK(it != dict.end() && !it->second.empty())
        << "Ill-formed LightGBM model file: need cat_threshold";
      tree.cat_threshold
        = treelite::common::TextToArray<uint32_t>(it->second,
                                                  tree.cat_boundaries.back());
    }

    it = dict.find("split_feature");
    CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
      << "Ill-formed LightGBM model file: need split_feature";
    tree.split_feature
      = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);

    it = dict.find("threshold");
    CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
      << "Ill-formed LightGBM model file: need threshold";
    tree.threshold
      = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);

    it = dict.find("split_gain");
    if (it != dict.end()) {
      CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
        << "Ill-formed LightGBM model file: split_gain cannot be empty string";
      tree.split_gain
        = treelite::common::TextToArray<float>(it->second, tree.num_leaves - 1);
    } else {
      tree.split_gain.resize(tree.num_leaves - 1);
    }

    it = dict.find("internal_count");
    if (it != dict.end()) {
      CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
        << "Ill-formed LightGBM model file: internal_count cannot be empty string";
      tree.internal_count
        = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
    } else {
      tree.internal_count.resize(tree.num_leaves - 1);
    }

    it = dict.find("leaf_count");
    if (it != dict.end()) {
      CHECK(!it->second.empty())
        << "Ill-formed LightGBM model file: leaf_count cannot be empty string";
      tree.leaf_count
        = treelite::common::TextToArray<int>(it->second, tree.num_leaves);
    } else {
      tree.leaf_count.resize(tree.num_leaves);
    }

    it = dict.find("left_child");
    CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
      << "Ill-formed LightGBM model file: need left_child";
    tree.left_child
      = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);

    it = dict.find("right_child");
    CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
      << "Ill-formed LightGBM model file: need right_child";
    tree.right_child
      = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
  }

  /* 2. Export model */
  treelite::Model model;
  model.num_feature = max_feature_idx_ + 1;
  model.num_output_group = num_tree_per_iteration_;
  if (model.num_output_group > 1) {
    // multiclass classification with gradient boosted trees
    CHECK(!average_output_)
      << "Ill-formed LightGBM model file: cannot use random forest mode "
      << "for multi-class classification";
    model.random_forest_flag = false;
  } else {
    model.random_forest_flag = average_output_;
  }

  // set correct prediction transform function, depending on objective function
  if (obj_name_ == "multiclass") {
    // validate num_class parameter
    int num_class = -1;
    int tmp;
    for (const auto& str : obj_param_) {
      auto tokens = treelite::common::Split(str, ':');
      if (tokens.size() == 2 && tokens[0] == "num_class"
        && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
        num_class = tmp;
        break;
      }
    }
    CHECK(num_class >= 0 && num_class == model.num_output_group)
      << "Ill-formed LightGBM model file: not a valid multiclass objective";

    model.param.pred_transform = "softmax";
  } else if (obj_name_ == "multiclassova") {
    // validate num_class and alpha parameters
    int num_class = -1;
    float alpha = -1.0f;
    int tmp;
    float tmp2;
    for (const auto& str : obj_param_) {
      auto tokens = treelite::common::Split(str, ':');
      if (tokens.size() == 2) {
        if (tokens[0] == "num_class"
          && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
          num_class = tmp;
        } else if (tokens[0] == "sigmoid"
         && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
          alpha = tmp2;
        }
      }
    }
    CHECK(num_class >= 0 && num_class == model.num_output_group
          && alpha > 0.0f)
      << "Ill-formed LightGBM model file: not a valid multiclassova objective";

    model.param.pred_transform = "multiclass_ova";
    model.param.sigmoid_alpha = alpha;
  } else if (obj_name_ == "binary") {
    // validate alpha parameter
    float alpha = -1.0f;
    float tmp;
    for (const auto& str : obj_param_) {
      auto tokens = treelite::common::Split(str, ':');
      if (tokens.size() == 2 && tokens[0] == "sigmoid"
        && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
        alpha = tmp;
        break;
      }
    }
    CHECK_GT(alpha, 0.0f)
      << "Ill-formed LightGBM model file: not a valid binary objective";

    model.param.pred_transform = "sigmoid";
    model.param.sigmoid_alpha = alpha;
  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
    model.param.pred_transform = "sigmoid";
    model.param.sigmoid_alpha = 1.0f;
  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
    model.param.pred_transform = "logarithm_one_plus_exp";
  } else {
    model.param.pred_transform = "identity";
  }

  // traverse trees
  for (const auto& lgb_tree : lgb_trees_) {
    model.trees.emplace_back();
    treelite::Tree& tree = model.trees.back();
    tree.Init();

    // assign node ID's so that a breadth-wise traversal would yield
    // the monotonic sequence 0, 1, 2, ...
    std::queue<std::pair<int, int>> Q;  // (old ID, new ID) pair
    Q.push({0, 0});
    while (!Q.empty()) {
      int old_id, new_id;
      std::tie(old_id, new_id) = Q.front(); Q.pop();
      if (old_id < 0) {  // leaf
        const double leaf_value = lgb_tree.leaf_value[~old_id];
        const int data_count = lgb_tree.leaf_count[~old_id];
        tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
        CHECK_GE(data_count, 0);
        tree[new_id].set_data_count(static_cast<size_t>(data_count));
      } else {  // non-leaf
        const int data_count = lgb_tree.internal_count[old_id];
        const unsigned split_index =
          static_cast<unsigned>(lgb_tree.split_feature[old_id]);

        tree.AddChilds(new_id);
        if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
          // categorical
          const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
          const std::vector<uint32_t> left_categories
            = BitsetToList(lgb_tree.cat_threshold.data()
                             + lgb_tree.cat_boundaries[cat_idx],
                           lgb_tree.cat_boundaries[cat_idx + 1]
                             - lgb_tree.cat_boundaries[cat_idx]);
          const auto missing_type
            = GetMissingType(lgb_tree.decision_type[old_id]);
          tree[new_id].set_categorical_split(split_index, false,
                                             (missing_type != MissingType::kNaN),
                                             left_categories);
        } else {
          // numerical
          const treelite::tl_float threshold =
            static_cast<treelite::tl_float>(lgb_tree.threshold[old_id]);
          const bool default_left
            = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
          const treelite::Operator cmp_op = treelite::Operator::kLE;
          tree[new_id].set_numerical_split(split_index, threshold,
                                           default_left, cmp_op);
        }
        CHECK_GE(data_count, 0);
        tree[new_id].set_data_count(static_cast<size_t>(data_count));
        tree[new_id].set_gain(static_cast<double>(lgb_tree.split_gain[old_id]));
        Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
        Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
      }
    }
  }
  LOG(INFO) << "model.num_tree = " << model.trees.size();
  return model;
}