void HandleMainNode()

in src/compiler/ast_native.cc [172:234]


  // Emits the code surrounding the main prediction routine.
  //
  // In order, this:
  //   1. wraps array_is_categorical_ (when non-empty) in a C array
  //      declaration -- note that this rewrites the member in place;
  //   2. appends the formatted native::main_start_template to `dest`;
  //   3. appends the formatted native::header_template to "header.h";
  //   4. walks the node's only child with indent + 2;
  //   5. appends the closing template to `dest`, choosing the multiclass
  //      variant when num_output_group_ > 1.
  //
  // node:   main node being handled; must have exactly one child
  // dest:   name of the buffer receiving the generated source
  // indent: indentation forwarded to AppendToBuffer
  void HandleMainNode(const MainNode* node,
                      const std::string& dest,
                      size_t indent) {
    const bool is_multiclass = (num_output_group_ > 1);

    // Function signatures shared by the generated source and header.h.
    const char* sig_num_output_group = "size_t get_num_output_group(void)";
    const char* sig_num_feature = "size_t get_num_feature(void)";
    const char* sig_predict
      = "float predict(union Entry* data, int pred_margin)";
    if (is_multiclass) {
      sig_predict
        = "size_t predict_multiclass(union Entry* data, int pred_margin, "
                                    "float* result)";
    }

    // Turn the accumulated entries into a full array declaration.
    // An empty member is left empty.
    if (!array_is_categorical_.empty()) {
      array_is_categorical_
        = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
                      array_is_categorical_);
    }

    // Opening part of the main source file.
    const std::string main_start_code
      = fmt::format(native::main_start_template,
          "array_is_categorical"_a = array_is_categorical_,
          "get_num_output_group_function_signature"_a = sig_num_output_group,
          "get_num_feature_function_signature"_a = sig_num_feature,
          "pred_transform_function"_a = pred_tranform_func_,
          "predict_function_signature"_a = sig_predict,
          "num_output_group"_a = num_output_group_,
          "num_feature"_a = node->num_feature);
    AppendToBuffer(dest, main_start_code, indent);

    // Header file declaring the same functions.
    const char* threshold_type = "double";
    if (param.quantize > 0) {
      threshold_type = "int";
    }
    const std::string header_code
      = fmt::format(native::header_template,
          "dllexport"_a = DLLEXPORT_KEYWORD,
          "get_num_output_group_function_signature"_a = sig_num_output_group,
          "get_num_feature_function_signature"_a = sig_num_feature,
          "predict_function_signature"_a = sig_predict,
          "threshold_type"_a = threshold_type);
    AppendToBuffer("header.h", header_code, indent);

    // Body of the prediction function comes from the single child.
    CHECK_EQ(node->children.size(), 1);
    WalkAST(node->children[0], dest, indent + 2);

    // " / <num_tree>" when the result is averaged; empty otherwise.
    std::string average_suffix;
    if (node->average_result) {
      average_suffix = fmt::format(" / {}", node->num_tree);
    }
    const std::string global_bias_text
      = common::ToStringHighPrecision(node->global_bias);

    // Closing part of the main source file.
    std::string main_end_code;
    if (is_multiclass) {
      main_end_code
        = fmt::format(native::main_end_multiclass_template,
            "num_output_group"_a = num_output_group_,
            "optional_average_field"_a = average_suffix,
            "global_bias"_a = global_bias_text);
    } else {
      main_end_code
        = fmt::format(native::main_end_template,
            "optional_average_field"_a = average_suffix,
            "global_bias"_a = global_bias_text);
    }
    AppendToBuffer(dest, main_end_code, indent);
  }