in lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc [144:263]
// Write the results document for a single data frame row to the JSON writer.
//
// The emitted object contains, in this order:
//   - the predicted class, i.e. the class with the highest score,
//   - the probability and the score of the predicted class,
//   - a flag which is true when the row's dependent variable is not missing,
//   - if m_NumTopClasses is not zero, the classes ordered by descending score
//     (all of them if m_NumTopClasses is -1, otherwise at most m_NumTopClasses),
//   - if featureImportance is not null, the non-zero per-feature SHAP values
//     broken down by class.
//
// frame                           Supplies the category values of the dependent
//                                 variable column.
// columnHoldingDependentVariable  Index of the dependent variable column.
// readClassProbabilities          Reads the per-class probabilities from row.
// readClassScores                 Reads the per-class scores from row.
// row                             The row whose results are written.
// writer                          Receives the JSON output.
// featureImportance               If not null, used to compute the SHAP values
//                                 of row. In that case m_InferenceModelMetadata
//                                 is also updated as a side effect.
//
// NOTE(review): this assumes scores is non-empty and that probabilities, scores
// and the category values are all indexed by the same class id; nothing here
// checks it - confirm against the callers.
void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
    const core::CDataFrame& frame,
    std::size_t columnHoldingDependentVariable,
    const TReadPredictionFunc& readClassProbabilities,
    const TReadClassScoresFunc& readClassScores,
    const TRowRef& row,
    core::CBoostJsonConcurrentLineWriter& writer,
    maths::analytics::CTreeShapFeatureImportance* featureImportance) const {

    auto probabilities = readClassProbabilities(row);
    auto scores = readClassScores(row);

    double actualClassId{row[columnHoldingDependentVariable]};
    // std::max_element returns the first maximum, so ties between scores are
    // resolved in favour of the smallest class id.
    std::size_t predictedClassId(std::max_element(scores.begin(), scores.end()) -
                                 scores.begin());
    const TStrVec& classValues{frame.categoricalColumnValues()[columnHoldingDependentVariable]};

    writer.onObjectBegin();
    writer.onKey(this->predictionFieldName());
    writePredictedCategoryValue(classValues[predictedClassId], writer);
    writer.onKey(PREDICTION_PROBABILITY_FIELD_NAME);
    writer.onDouble(probabilities[predictedClassId]);
    writer.onKey(PREDICTION_SCORE_FIELD_NAME);
    writer.onDouble(scores[predictedClassId]);
    writer.onKey(IS_TRAINING_FIELD_NAME);
    writer.onBool(maths::analytics::CDataFrameUtils::isMissing(actualClassId) == false);

    if (m_NumTopClasses != 0) {
        TSizeVec classIds(scores.size());
        std::iota(classIds.begin(), classIds.end(), 0);
        // A stable sort keeps classes with equal scores in ascending class id
        // order. This makes the output deterministic and guarantees the first
        // entry is the same class as predictedClassId chosen above.
        std::stable_sort(classIds.begin(), classIds.end(),
                         [&scores](std::size_t lhs, std::size_t rhs) {
                             return scores[lhs] > scores[rhs];
                         });
        // -1 is a special value meaning "output all the classes"
        classIds.resize(m_NumTopClasses == -1
                            ? classIds.size()
                            : std::min(classIds.size(),
                                       static_cast<std::size_t>(m_NumTopClasses)));
        writer.onKey(TOP_CLASSES_FIELD_NAME);
        writer.onArrayBegin();
        for (std::size_t i : classIds) {
            writer.onObjectBegin();
            writer.onKey(CLASS_NAME_FIELD_NAME);
            writePredictedCategoryValue(classValues[i], writer);
            writer.onKey(CLASS_PROBABILITY_FIELD_NAME);
            writer.onDouble(probabilities[i]);
            writer.onKey(CLASS_SCORE_FIELD_NAME);
            writer.onDouble(scores[i]);
            writer.onObjectEnd();
        }
        writer.onArrayEnd();
    }

    if (featureImportance != nullptr) {
        int numberClasses{static_cast<int>(classValues.size())};
        m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
        m_InferenceModelMetadata.classValues(classValues);
        m_InferenceModelMetadata.predictionFieldTypeResolverWriter(
            [this](const std::string& categoryValue,
                   core::CBoostJsonConcurrentLineWriter& writer_) {
                this->writePredictedCategoryValue(categoryValue, writer_);
            });
        featureImportance->shap(
            row, [&](const maths::analytics::CTreeShapFeatureImportance::TSizeVec& indices,
                     const TStrVec& featureNames,
                     const maths::analytics::CTreeShapFeatureImportance::TVectorVec& shap) {
                // Writes one {class name, importance} object.
                auto writeClassImportance = [&](int classId, double importance) {
                    writer.onObjectBegin();
                    writer.onKey(CLASS_NAME_FIELD_NAME);
                    writePredictedCategoryValue(classValues[classId], writer);
                    writer.onKey(IMPORTANCE_FIELD_NAME);
                    writer.onDouble(importance);
                    writer.onObjectEnd();
                };

                writer.onKey(FEATURE_IMPORTANCE_FIELD_NAME);
                writer.onArrayBegin();
                for (auto i : indices) {
                    // Features with no importance for this row are omitted.
                    if (shap[i].norm() != 0.0) {
                        writer.onObjectBegin();
                        writer.onKey(FEATURE_NAME_FIELD_NAME);
                        writer.onString(featureNames[i]);
                        writer.onKey(CLASSES_FIELD_NAME);
                        writer.onArrayBegin();
                        if (shap[i].size() == 1) {
                            // Binary case: there is a single SHAP value. It is
                            // reported as is for the class with id 1 and with
                            // the opposite sign for every other class.
                            for (int j = 0; j < numberClasses; ++j) {
                                writeClassImportance(j, j == 1 ? shap[i](0)
                                                               : -shap[i](0));
                            }
                        } else {
                            // Multiclass case: one SHAP value per class.
                            for (int j = 0; j < shap[i].size() && j < numberClasses; ++j) {
                                writeClassImportance(j, shap[i](j));
                            }
                        }
                        writer.onArrayEnd();
                        writer.onObjectEnd();
                    }
                }
                writer.onArrayEnd();

                // Accumulate every non-zero SHAP vector into the model metadata.
                // m_InferenceModelMetadata is already modified directly from
                // this const method above, so no const_cast on this is needed.
                for (std::size_t i = 0; i < shap.size(); ++i) {
                    if (shap[i].lpNorm<1>() != 0) {
                        m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
                    }
                }
            });
    }
    writer.onObjectEnd();
}