in tensorflow_model_analysis/frontend/tfma-multi-class-confusion-matrix-at-thresholds/tfma-multi-class-confusion-matrix-at-thresholds.js [453:530]
computeSummaryMatrix_(entries, multiLabel, classIds) {
if (!entries || !classIds) {
return undefined;
}
const matrix = {};
for (const entry of entries) {
const actual = entry[FieldNames.ACTUAL_CLASS_ID] || 0;
const predicted = entry[FieldNames.PREDICTED_CLASS_ID] || 0;
if (!matrix[actual]) {
matrix[actual] = {entries: {}};
}
const row = matrix[actual].entries;
const isDiagonal = predicted == actual;
const truePositives = entry[FieldNames.TRUE_POSITIVES] || 0;
const falsePositives = entry[FieldNames.FALSE_POSITIVES] || 0;
const numWeightedExamples = entry[FieldNames.NUM_WEIGHTED_EXAMPLES] || 0;
const truePositivesToUse =
multiLabel ? truePositives : (isDiagonal ? numWeightedExamples : 0);
const falsePositivesToUse =
multiLabel ? falsePositives : (isDiagonal ? 0 : numWeightedExamples);
const falseNegatives = entry[FieldNames.FALSE_NEGATIVES] || 0;
row[predicted] = {
positives: truePositivesToUse + falsePositivesToUse,
truePositives: truePositivesToUse,
falsePositives: falsePositivesToUse,
falseNegatives: falseNegatives,
};
}
classIds.forEach((rowId) => {
if (rowId == NO_PREDICTION_CLASS_ID) {
// Do not create a row for no prediction.
return;
}
// Fill in holes in the matrix.
if (!matrix[rowId]) {
matrix[rowId] = {entries: {}};
}
const currentRow = matrix[rowId].entries;
let totalPositives = 0;
let totalTruePositives = 0;
let totalFalsePositives = 0;
let totalFalseNegatives = 0;
let noPrediction = 0;
for (let columnId of /** @type {!Array<string>} */ (classIds)) {
if (!currentRow[columnId]) {
currentRow[columnId] = {
positives: 0,
truePositives: 0,
falsePositives: 0,
falseNegatives: 0,
};
}
// Skip no prediciton.
if (columnId == NO_PREDICTION_CLASS_ID) {
noPrediction = currentRow[columnId].falsePositives;
continue;
}
totalPositives += currentRow[columnId].positives;
totalTruePositives += currentRow[columnId].truePositives;
totalFalsePositives += currentRow[columnId].falsePositives;
totalFalseNegatives += currentRow[columnId].falseNegatives;
}
matrix[rowId].totalPositives = totalPositives;
matrix[rowId].totalTruePositives = totalTruePositives;
matrix[rowId].totalFalsePositives = totalFalsePositives;
matrix[rowId].totalFalseNegatives = totalFalseNegatives;
matrix[rowId].totalNoPrediction = noPrediction;
});
return /** @type {!SummaryMatrix} */ (matrix);
}