void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow()

in lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc [144:263]


// Write the JSON object holding the classification results for one row.
//
// The object contains, in order:
//   - the predicted class, i.e. the class with the largest score, together
//     with its probability and score,
//   - the is-training flag, which is true when the row's dependent variable
//     is not missing,
//   - if m_NumTopClasses is non-zero, the classes ranked by decreasing score,
//     truncated to m_NumTopClasses entries unless it is -1,
//   - if featureImportance is not null, the per-class SHAP values of every
//     feature whose SHAP vector is non-zero.
//
// When featureImportance is supplied the inference model metadata is also
// updated with the column names, class values and non-zero SHAP vectors.
void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
    const core::CDataFrame& frame,
    std::size_t columnHoldingDependentVariable,
    const TReadPredictionFunc& readClassProbabilities,
    const TReadClassScoresFunc& readClassScores,
    const TRowRef& row,
    core::CBoostJsonConcurrentLineWriter& writer,
    maths::analytics::CTreeShapFeatureImportance* featureImportance) const {

    auto classProbabilities = readClassProbabilities(row);
    auto classScores = readClassScores(row);

    double dependentVariable{row[columnHoldingDependentVariable]};

    // The prediction is the class with the highest score.
    auto bestScore = std::max_element(classScores.begin(), classScores.end());
    auto winningClassId = static_cast<std::size_t>(bestScore - classScores.begin());

    const TStrVec& classValues{frame.categoricalColumnValues()[columnHoldingDependentVariable]};

    writer.onObjectBegin();
    writer.onKey(this->predictionFieldName());
    writePredictedCategoryValue(classValues[winningClassId], writer);
    writer.onKey(PREDICTION_PROBABILITY_FIELD_NAME);
    writer.onDouble(classProbabilities[winningClassId]);
    writer.onKey(PREDICTION_SCORE_FIELD_NAME);
    writer.onDouble(classScores[winningClassId]);
    writer.onKey(IS_TRAINING_FIELD_NAME);
    writer.onBool(maths::analytics::CDataFrameUtils::isMissing(dependentVariable) == false);

    if (m_NumTopClasses != 0) {
        // Rank the class identifiers by decreasing score.
        TSizeVec rankedClassIds(classScores.size());
        std::iota(rankedClassIds.begin(), rankedClassIds.end(), 0);
        auto hasHigherScore = [&classScores](std::size_t lhs, std::size_t rhs) {
            return classScores[lhs] > classScores[rhs];
        };
        std::sort(rankedClassIds.begin(), rankedClassIds.end(), hasHigherScore);

        // -1 is a special value meaning "output all the classes", in which
        // case the ranking is left at its full length.
        if (m_NumTopClasses != -1) {
            rankedClassIds.resize(std::min(rankedClassIds.size(),
                                           static_cast<std::size_t>(m_NumTopClasses)));
        }

        writer.onKey(TOP_CLASSES_FIELD_NAME);
        writer.onArrayBegin();
        for (std::size_t classId : rankedClassIds) {
            writer.onObjectBegin();
            writer.onKey(CLASS_NAME_FIELD_NAME);
            writePredictedCategoryValue(classValues[classId], writer);
            writer.onKey(CLASS_PROBABILITY_FIELD_NAME);
            writer.onDouble(classProbabilities[classId]);
            writer.onKey(CLASS_SCORE_FIELD_NAME);
            writer.onDouble(classScores[classId]);
            writer.onObjectEnd();
        }
        writer.onArrayEnd();
    }

    if (featureImportance != nullptr) {
        int numberClasses{static_cast<int>(classValues.size())};

        // Emits one {class name, importance} object.
        auto writeClassImportance = [this, &classValues, &writer](int classId, double importance) {
            writer.onObjectBegin();
            writer.onKey(CLASS_NAME_FIELD_NAME);
            this->writePredictedCategoryValue(classValues[classId], writer);
            writer.onKey(IMPORTANCE_FIELD_NAME);
            writer.onDouble(importance);
            writer.onObjectEnd();
        };

        m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
        m_InferenceModelMetadata.classValues(classValues);
        m_InferenceModelMetadata.predictionFieldTypeResolverWriter(
            [this](const std::string& category,
                   core::CBoostJsonConcurrentLineWriter& categoryWriter) {
                this->writePredictedCategoryValue(category, categoryWriter);
            });

        featureImportance->shap(
            row, [&](const maths::analytics::CTreeShapFeatureImportance::TSizeVec& indices,
                     const TStrVec& featureNames,
                     const maths::analytics::CTreeShapFeatureImportance::TVectorVec& shap) {
                writer.onKey(FEATURE_IMPORTANCE_FIELD_NAME);
                writer.onArrayBegin();
                for (auto feature : indices) {
                    const auto& featureShap = shap[feature];
                    // Features which have no importance are not reported.
                    if (featureShap.norm() == 0.0) {
                        continue;
                    }
                    writer.onObjectBegin();
                    writer.onKey(FEATURE_NAME_FIELD_NAME);
                    writer.onString(featureNames[feature]);
                    writer.onKey(CLASSES_FIELD_NAME);
                    writer.onArrayBegin();
                    if (featureShap.size() == 1) {
                        // Binary case: a single SHAP value is stored. It is
                        // reported as is for class 1 and negated for every
                        // other class.
                        for (int classId = 0; classId < numberClasses; ++classId) {
                            writeClassImportance(classId, classId == 1 ? featureShap(0)
                                                                       : -featureShap(0));
                        }
                    } else {
                        // Multiclass case: one SHAP value is stored per class.
                        for (int classId = 0;
                             classId < featureShap.size() && classId < numberClasses;
                             ++classId) {
                            writeClassImportance(classId, featureShap(classId));
                        }
                    }
                    writer.onArrayEnd();
                    writer.onObjectEnd();
                }
                writer.onArrayEnd();

                // Accumulate the non-zero SHAP vectors in the model metadata.
                // This method is const so the cast is used to reach the
                // mutating call.
                auto* self = const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this);
                for (std::size_t feature = 0; feature < shap.size(); ++feature) {
                    if (shap[feature].lpNorm<1>() != 0) {
                        self->m_InferenceModelMetadata.addToFeatureImportance(
                            feature, shap[feature]);
                    }
                }
            });
    }
    writer.onObjectEnd();
}