computeSummaryMatrix_()

in tensorflow_model_analysis/frontend/tfma-multi-class-confusion-matrix-at-thresholds/tfma-multi-class-confusion-matrix-at-thresholds.js [453:530]


  computeSummaryMatrix_(entries, multiLabel, classIds) {
    if (!entries || !classIds) {
      return undefined;
    }

    const matrix = {};
    for (const entry of entries) {
      const actual = entry[FieldNames.ACTUAL_CLASS_ID] || 0;
      const predicted = entry[FieldNames.PREDICTED_CLASS_ID] || 0;
      if (!matrix[actual]) {
        matrix[actual] = {entries: {}};
      }
      const row = matrix[actual].entries;
      const isDiagonal = predicted == actual;
      const truePositives = entry[FieldNames.TRUE_POSITIVES] || 0;
      const falsePositives = entry[FieldNames.FALSE_POSITIVES] || 0;
      const numWeightedExamples = entry[FieldNames.NUM_WEIGHTED_EXAMPLES] || 0;

      const truePositivesToUse =
          multiLabel ? truePositives : (isDiagonal ? numWeightedExamples : 0);
      const falsePositivesToUse =
          multiLabel ? falsePositives : (isDiagonal ? 0 : numWeightedExamples);
      const falseNegatives = entry[FieldNames.FALSE_NEGATIVES] || 0;

      row[predicted] = {
        positives: truePositivesToUse + falsePositivesToUse,
        truePositives: truePositivesToUse,
        falsePositives: falsePositivesToUse,
        falseNegatives: falseNegatives,
      };
    }

    classIds.forEach((rowId) => {
      if (rowId == NO_PREDICTION_CLASS_ID) {
        // Do not create a row for no prediction.
        return;
      }

      // Fill in holes in the matrix.
      if (!matrix[rowId]) {
        matrix[rowId] = {entries: {}};
      }
      const currentRow = matrix[rowId].entries;
      let totalPositives = 0;
      let totalTruePositives = 0;
      let totalFalsePositives = 0;
      let totalFalseNegatives = 0;
      let noPrediction = 0;

      for (let columnId of /** @type {!Array<string>} */ (classIds)) {
        if (!currentRow[columnId]) {
          currentRow[columnId] = {
            positives: 0,
            truePositives: 0,
            falsePositives: 0,
            falseNegatives: 0,
          };
        }

        // Skip no prediciton.
        if (columnId == NO_PREDICTION_CLASS_ID) {
          noPrediction = currentRow[columnId].falsePositives;
          continue;
        }
        totalPositives += currentRow[columnId].positives;
        totalTruePositives += currentRow[columnId].truePositives;
        totalFalsePositives += currentRow[columnId].falsePositives;
        totalFalseNegatives += currentRow[columnId].falseNegatives;
      }

      matrix[rowId].totalPositives = totalPositives;
      matrix[rowId].totalTruePositives = totalTruePositives;
      matrix[rowId].totalFalsePositives = totalFalsePositives;
      matrix[rowId].totalFalseNegatives = totalFalseNegatives;
      matrix[rowId].totalNoPrediction = noPrediction;
    });
    return /** @type {!SummaryMatrix} */ (matrix);
  }