in src/compiler/ast_native.cc [172:234]
void HandleMainNode(const MainNode* node,
const std::string& dest,
size_t indent) {
const char* get_num_output_group_function_signature
= "size_t get_num_output_group(void)";
const char* get_num_feature_function_signature
= "size_t get_num_feature(void)";
const char* predict_function_signature
= (num_output_group_ > 1) ?
"size_t predict_multiclass(union Entry* data, int pred_margin, "
"float* result)"
: "float predict(union Entry* data, int pred_margin)";
if (!array_is_categorical_.empty()) {
array_is_categorical_
= fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
array_is_categorical_);
}
AppendToBuffer(dest,
fmt::format(native::main_start_template,
"array_is_categorical"_a = array_is_categorical_,
"get_num_output_group_function_signature"_a
= get_num_output_group_function_signature,
"get_num_feature_function_signature"_a
= get_num_feature_function_signature,
"pred_transform_function"_a = pred_tranform_func_,
"predict_function_signature"_a = predict_function_signature,
"num_output_group"_a = num_output_group_,
"num_feature"_a = node->num_feature),
indent);
AppendToBuffer("header.h",
fmt::format(native::header_template,
"dllexport"_a = DLLEXPORT_KEYWORD,
"get_num_output_group_function_signature"_a
= get_num_output_group_function_signature,
"get_num_feature_function_signature"_a
= get_num_feature_function_signature,
"predict_function_signature"_a = predict_function_signature,
"threshold_type"_a = (param.quantize > 0 ? "int" : "double")),
indent);
CHECK_EQ(node->children.size(), 1);
WalkAST(node->children[0], dest, indent + 2);
const std::string optional_average_field
= (node->average_result) ? fmt::format(" / {}", node->num_tree)
: std::string("");
if (num_output_group_ > 1) {
AppendToBuffer(dest,
fmt::format(native::main_end_multiclass_template,
"num_output_group"_a = num_output_group_,
"optional_average_field"_a = optional_average_field,
"global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
indent);
} else {
AppendToBuffer(dest,
fmt::format(native::main_end_template,
"optional_average_field"_a = optional_average_field,
"global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
indent);
}
}