CompiledModel Compile()

in src/compiler/failsafe.cc [267:369]


  /*!
   * \brief Translate a gradient-boosted tree ensemble into a set of C source
   *        files (plus an optional ELF object) and a JSON build recipe.
   * \param model the tree ensemble to compile; a random forest triggers a
   *        CHECK failure.
   * \return a CompiledModel with backend "native" whose file set contains
   *         main.c, header.h, recipe.json and either arrays.c or arrays.o
   *         (arrays.o when param.dump_array_as_elf > 0).
   * \note Side effects: overwrites num_feature_, num_output_group_ and
   *       pred_tranform_func_; files_ is cleared on entry and is left in a
   *       moved-from state on return.
   */
  CompiledModel Compile(const Model& model) override {
    CompiledModel result;
    result.backend = "native";

    // Remember model dimensions first; the CHECK below runs afterwards.
    num_feature_ = model.num_feature;
    num_output_group_ = model.num_output_group;
    CHECK(!model.random_forest_flag)
      << "Only gradient boosted trees supported in FailSafeCompiler";
    pred_tranform_func_ = PredTransformFunction("native", model);
    files_.clear();

    // Multiclass models accumulate into a per-group array and expose a
    // different entry point than single-output models.
    const bool is_multiclass = (num_output_group_ > 1);
    const std::string bias_literal
      = common::ToStringHighPrecision(model.param.global_bias);

    const char* predict_function_signature = nullptr;
    std::string accumulator_definition;
    std::string output_statement;
    std::string return_statement;
    if (is_multiclass) {
      predict_function_signature
        = "size_t predict_multiclass(union Entry* data, int pred_margin, "
          "float* result)";
      accumulator_definition
        = fmt::format("float sum[{num_output_group}] = {{0.0f}}",
                      "num_output_group"_a = num_output_group_);
      output_statement
        = fmt::format("sum[tree_id % {num_output_group}] += tree[nid].info.leaf_value;",
                      "num_output_group"_a = num_output_group_);
      return_statement
        = fmt::format(return_multiclass_template,
                      "num_output_group"_a = num_output_group_,
                      "global_bias"_a = bias_literal);
    } else {
      predict_function_signature = "float predict(union Entry* data, int pred_margin)";
      accumulator_definition = std::string("float sum = 0.0f");
      output_statement = std::string("sum += tree[nid].info.leaf_value;");
      return_statement = fmt::format(return_template, "global_bias"_a = bias_literal);
    }

    // The node table is emitted either as C source text or as a prebuilt
    // relocatable object; the row-pointer table is always C text.
    const bool emit_elf = (param.dump_array_as_elf > 0);
    std::string nodes, nodes_row_ptr;
    std::vector<char> nodes_elf;
    if (!emit_elf) {
      std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
    } else {
      if (param.verbose > 0) {
        LOG(INFO) << "Dumping arrays as an ELF relocatable object...";
      }
      std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
    }

    std::string main_source = fmt::format(main_template,
      "nodes_row_ptr"_a = nodes_row_ptr,
      "pred_transform_function"_a = pred_tranform_func_,
      "predict_function_signature"_a = predict_function_signature,
      "num_output_group"_a = num_output_group_,
      "num_feature"_a = num_feature_,
      "num_tree"_a = model.trees.size(),
      "compare_op"_a = GetCommonOp(model),
      "accumulator_definition"_a = accumulator_definition,
      "output_statement"_a = output_statement,
      "return_statement"_a = return_statement);
    files_["main.c"] = CompiledModel::FileEntry(std::move(main_source));

    if (emit_elf) {
      files_["arrays.o"] = CompiledModel::FileEntry(std::move(nodes_elf));
    } else {
      files_["arrays.c"]
        = CompiledModel::FileEntry(fmt::format(arrays_template, "nodes"_a = nodes));
    }

    files_["header.h"] = CompiledModel::FileEntry(fmt::format(header_template,
      "dllexport"_a = DLLEXPORT_KEYWORD,
      "predict_function_signature"_a = predict_function_signature));

    // Build recipe.json from the files generated so far: each *.c file becomes
    // a {name, length} record (name without the ".c" suffix, length = number
    // of '\n' characters), and each *.o file is listed under "extra".
    const auto make_recipe = [&]() -> std::string {
      std::vector<std::unordered_map<std::string, std::string>> sources;
      std::vector<std::string> extras;
      for (const auto& file : files_) {
        const std::string& filename = file.first;
        if (EndsWith(filename, ".c")) {
          const auto& text = file.second.content;
          const size_t newline_count = std::count(text.begin(), text.end(), '\n');
          std::unordered_map<std::string, std::string> source_record{
            {"name", filename.substr(0, filename.length() - 2)},
            {"length", std::to_string(newline_count)}
          };
          sources.push_back(std::move(source_record));
        } else if (EndsWith(filename, ".o")) {
          extras.push_back(filename);
        }
      }
      std::ostringstream recipe_stream;
      dmlc::JSONWriter writer(&recipe_stream);
      writer.BeginObject();
      writer.WriteObjectKeyValue("target", param.native_lib_name);
      writer.WriteObjectKeyValue("sources", sources);
      if (!extras.empty()) {
        writer.WriteObjectKeyValue("extra", extras);
      }
      writer.EndObject();
      return recipe_stream.str();
    };
    // Materialise the recipe before touching files_["recipe.json"], so the
    // iteration above never observes the new entry.
    std::string recipe_json = make_recipe();
    files_["recipe.json"] = CompiledModel::FileEntry(std::move(recipe_json));

    result.files = std::move(files_);
    return result;
  }