in src/boosting/gbdt_model_text.cpp [421:625]
// Restores this GBDT object from the text model stored in `buffer`.
//
// The text is consumed front to back through a single cursor `p`:
//   1. header: "key=value" lines up to the first line starting with "Tree=";
//   2. trees: one Tree per "Tree=" section, parsed sequentially, or inside an
//      OpenMP loop when the header carries a "tree_sizes" table;
//   3. the lines between "parameters:" and "end of parameters";
//   4. the lines between "parser:" and "end of parser".
//
// buffer: start of the model text. NOTE(review): Common::GetLine is given no
//         end pointer, so the text is presumably expected to be
//         NUL-terminated at or after `buffer + len` -- confirm with callers.
// len:    number of bytes of `buffer` that belong to the model.
//
// Returns true on success. Every error path calls Log::Fatal first; the
// `return false` that follows only matters if Log::Fatal returns (its
// definition is not visible from this function).
bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
  // use serialized string to restore this object
  models_.clear();
  auto c_str = buffer;
  auto p = c_str;
  auto end = p + len;
  std::unordered_map<std::string, std::string> key_vals;
  // Pass 1: collect header key/value pairs until the first "Tree=" line.
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
      std::string cur_line(p, line_len);
      if (!Common::StartsWith(cur_line, "Tree=")) {
        auto strs = Common::Split(cur_line.c_str(), '=');
        if (strs.size() == 1) {
          // A bare key (no '=') is stored as a flag with an empty value.
          key_vals[strs[0]] = "";
        } else if (strs.size() == 2) {
          key_vals[strs[0]] = strs[1];
        } else if (strs.size() > 2) {
          // Only these two keys may carry '=' inside their value; keep the
          // whole remainder of the line for them.
          if (strs[0] == "feature_names") {
            key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
          } else if (strs[0] == "monotone_constraints") {
            key_vals[strs[0]] = cur_line.substr(std::strlen("monotone_constraints="));
          } else {
            // Use first 128 chars to avoid exceed the message buffer.
            Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
          }
        }
      } else {
        // Leave `p` on the first "Tree=" line for the tree-loading pass.
        break;
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  // get number of classes
  if (key_vals.count("num_class")) {
    Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
  } else {
    Log::Fatal("Model file doesn't specify the number of classes");
    return false;
  }
  if (key_vals.count("num_tree_per_iteration")) {
    Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
  } else {
    num_tree_per_iteration_ = num_class_;
  }
  // models_.size() is divided by this value further down; a zero or negative
  // count (e.g. from a corrupted header) must be rejected before that
  // division instead of triggering an integer division by zero.
  if (num_tree_per_iteration_ <= 0) {
    Log::Fatal("Model file specifies an invalid number of trees per iteration: %d",
               static_cast<int>(num_tree_per_iteration_));
    return false;
  }
  // get index of label
  if (key_vals.count("label_index")) {
    Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
  } else {
    Log::Fatal("Model file doesn't specify the label index");
    return false;
  }
  // get max_feature_idx first
  if (key_vals.count("max_feature_idx")) {
    Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
  } else {
    Log::Fatal("Model file doesn't specify max_feature_idx");
    return false;
  }
  // get average_output (presence of the key is the flag; it is never reset
  // to false here)
  if (key_vals.count("average_output")) {
    average_output_ = true;
  }
  // get feature names
  if (key_vals.count("feature_names")) {
    feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
    if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_names");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature_names");
    return false;
  }
  // get monotone_constraints (optional key)
  if (key_vals.count("monotone_constraints")) {
    monotone_constraints_ = CommonC::StringToArray<int8_t>(key_vals["monotone_constraints"].c_str(), ' ');
    if (monotone_constraints_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of monotone_constraints");
      return false;
    }
  }
  // get feature_infos
  if (key_vals.count("feature_infos")) {
    feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
    if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
      Log::Fatal("Wrong size of feature_infos");
      return false;
    }
  } else {
    Log::Fatal("Model file doesn't contain feature_infos");
    return false;
  }
  // get objective (optional key); the created object is owned by
  // loaded_objective_ and exposed through objective_function_
  if (key_vals.count("objective")) {
    auto str = key_vals["objective"];
    loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(ParseObjectiveAlias(str)));
    objective_function_ = loaded_objective_.get();
  }
  // Pass 2: load the trees.
  if (!key_vals.count("tree_sizes")) {
    // No size table: parse trees one after another, advancing `p` by the
    // length each Tree constructor reports through `used_len`.
    while (p < end) {
      auto line_len = Common::GetLine(p);
      if (line_len > 0) {
        std::string cur_line(p, line_len);
        if (Common::StartsWith(cur_line, "Tree=")) {
          p += line_len;
          p = Common::SkipNewLine(p);
          size_t used_len = 0;
          models_.emplace_back(new Tree(p, &used_len));
          p += used_len;
        } else {
          break;
        }
      }
      p = Common::SkipNewLine(p);
    }
  } else {
    // Size table present: tree i starts at p + (sum of the first i sizes),
    // so every tree can be located up front and parsed independently.
    std::vector<size_t> tree_sizes = CommonC::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
    std::vector<size_t> tree_boundaries(tree_sizes.size() + 1, 0);
    int num_trees = static_cast<int>(tree_sizes.size());
    // The offsets come from the model text itself, so validate every start
    // offset against the bytes left in the buffer before forming a pointer
    // from it; otherwise a corrupted table would make the loop below read
    // outside the buffer.
    const size_t remaining = p < end ? static_cast<size_t>(end - p) : 0;
    for (int i = 0; i < num_trees; ++i) {
      if (tree_boundaries[i] >= remaining) {
        Log::Fatal("Model format error, tree_sizes points beyond the end of the model string");
        return false;
      }
      tree_boundaries[i + 1] = tree_boundaries[i] + tree_sizes[i];
      // Pre-size models_ so each loop iteration only writes its own slot.
      models_.emplace_back(nullptr);
    }
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_trees; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto cur_p = p + tree_boundaries[i];
      auto line_len = Common::GetLine(cur_p);
      std::string cur_line(cur_p, line_len);
      if (Common::StartsWith(cur_line, "Tree=")) {
        cur_p += line_len;
        cur_p = Common::SkipNewLine(cur_p);
        size_t used_len = 0;
        models_[i].reset(new Tree(cur_p, &used_len));
      } else {
        Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // `p` is not advanced in this branch: the parameter scan below starts at
    // the first tree and walks over the tree text line by line.
  }
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
  bool is_inparameter = false, is_inparser = false;
  std::stringstream ss;
  Common::C_stringstream(ss);
  // Pass 3: copy every line between "parameters:" and "end of parameters".
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
      std::string cur_line(p, line_len);
      if (cur_line == std::string("parameters:")) {
        is_inparameter = true;
      } else if (cur_line == std::string("end of parameters")) {
        // `p` stays on this line; the parser scan below skips over it.
        break;
      } else if (is_inparameter) {
        ss << cur_line << "\n";
        if (Common::StartsWith(cur_line, "[linear_tree: ")) {
          // The single character right after the 14-character prefix
          // "[linear_tree: " holds the flag value.
          int is_linear = 0;
          Common::Atoi(cur_line.substr(14, 1).c_str(), &is_linear);
          linear_tree_ = static_cast<bool>(is_linear);
        }
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  // An empty parameter block leaves loaded_parameter_ untouched.
  if (!ss.str().empty()) {
    loaded_parameter_ = ss.str();
  }
  ss.clear();
  ss.str("");
  // Pass 4: copy every line between "parser:" and "end of parser".
  while (p < end) {
    auto line_len = Common::GetLine(p);
    if (line_len > 0) {
      std::string cur_line(p, line_len);
      if (cur_line == std::string("parser:")) {
        is_inparser = true;
      } else if (cur_line == std::string("end of parser")) {
        p += line_len;
        p = Common::SkipNewLine(p);
        break;
      } else if (is_inparser) {
        ss << cur_line << "\n";
      }
    }
    p += line_len;
    p = Common::SkipNewLine(p);
  }
  parser_config_str_ = ss.str();
  ss.clear();
  ss.str("");
  return true;
}