export async function confusionMatrix()

in tfjs-vis/src/render/confusion_matrix.ts [69:247]


/**
 * Renders a confusion matrix as a heatmap by embedding a Vega-Lite spec into
 * the draw area resolved from `container`.
 *
 * Each entry `data.values[i][j]` becomes one cell: row `i` is plotted on the
 * y axis as `label` and column `j` on the x axis as `prediction`. When
 * `data.tickLabels` is missing or empty, labels of the form `Class <index>`
 * are generated instead. `data` itself is never modified.
 *
 * @param container Target to render into; resolved via `getDrawArea`.
 * @param data Matrix counts (`values`) and optional axis labels
 *     (`tickLabels`), used in order as the domain of both axes.
 * @param opts Options merged over `defaultOpts`. When `shadeDiagonal` is
 *     falsy, diagonal cells are drawn as unfilled white rects and kept out of
 *     the color scale, unless every off-diagonal count is zero. In that case
 *     the diagonal is shaded anyway. `showTextOverlay` draws each cell's
 *     count. A string `colorMap` is used as `{scheme: colorMap}`; any other
 *     value is passed through as the color range.
 * @returns A promise that resolves once the chart has been embedded.
 */
export async function confusionMatrix(
    container: Drawable, data: ConfusionMatrixData,
    opts: ConfusionMatrixOptions = {}): Promise<void> {
  const options = Object.assign({}, defaultOpts, opts);
  const drawArea = getDrawArea(container);

  // Evaluate the flag once so that building the cell data and choosing the
  // chart layers always agree. This matters when a caller passes
  // `shadeDiagonal: undefined`, which Object.assign copies over the default.
  const shadeDiagonal = Boolean(options.shadeDiagonal);

  // Format data for the vega spec: an array of objects, one for each cell in
  // the matrix.
  const values: MatrixEntry[] = [];

  const inputArray = data.values;
  // Work on a copy so generated labels are never pushed into the caller's
  // array.
  const tickLabels = data.tickLabels != null ? data.tickLabels.slice() : [];
  const generateLabels = tickLabels.length === 0;

  let nonDiagonalIsAllZeroes = true;
  for (let i = 0; i < inputArray.length; i++) {
    const label = generateLabels ? `Class ${i}` : tickLabels[i];

    if (generateLabels) {
      tickLabels.push(label);
    }

    for (let j = 0; j < inputArray[i].length; j++) {
      const prediction = generateLabels ? `Class ${j}` : tickLabels[j];

      const count = inputArray[i][j];
      if (i === j && !shadeDiagonal) {
        // No scaleCount: unfilled diagonal cells stay out of the color scale.
        values.push({
          label,
          prediction,
          count,
          noFill: true,
        });
      } else {
        values.push({
          label,
          prediction,
          count,
          scaleCount: count,
        });
        // Track whether any shaded cell is non-zero. If all of them are zero
        // the color scale would be degenerate (see the fallback below).
        if (count !== 0) {
          nonDiagonalIsAllZeroes = false;
        }
      }
    }
  }

  if (!shadeDiagonal && nonDiagonalIsAllZeroes) {
    // The caller asked us not to shade the diagonal, but every other value is
    // zero. We can either shade nothing or shade only the diagonal. We shade
    // the diagonal, since that is likely more helpful even though it is not
    // what the caller requested.
    for (const val of values) {
      if (val.noFill === true) {
        val.noFill = false;
        val.scaleCount = val.count;
      }
    }
  }

  // A string colorMap is wrapped as a scheme reference; anything else is
  // passed through unchanged as the color range.
  const colorRange = typeof options.colorMap === 'string' ?
      {scheme: options.colorMap} :
      options.colorMap;

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  //@ts-ignore
  const spec: VisualizationSpec = {
    'width': options.width || getDefaultWidth(drawArea),
    'height': options.height || getDefaultHeight(drawArea),
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      }
    },
    //@ts-ignore
    'data': {'values': values},
    'encoding': {
      'x': {
        'field': 'prediction',
        'type': 'ordinal',
        'title': options.xLabel || 'prediction',
        // Keep the axis order given by tickLabels.
        'scale': {'domain': tickLabels},
      },
      'y': {
        'field': 'label',
        'type': 'ordinal',
        'title': options.yLabel || 'label',
        // Keep the axis order given by tickLabels.
        'scale': {'domain': tickLabels},
      },
    },
    'layer': [
      {
        // The shaded matrix cells.
        'transform': [
          {'filter': 'datum.noFill != true'},
        ],
        'mark': {
          'type': 'rect',
        },
        'encoding': {
          'color': {
            'field': 'scaleCount',
            'type': 'quantitative',
            //@ts-ignore
            'scale': {'range': colorRange},
          },
          'tooltip': [
            {'field': 'label', 'type': 'nominal'},
            {'field': 'prediction', 'type': 'nominal'},
            {'field': 'count', 'type': 'quantitative'},
          ]
        },
      },
    ]
  };

  if (!shadeDiagonal) {
    //@ts-ignore
    spec.layer.push(
        {
          // Render unfilled (white) rects for the diagonal cells.
          'transform': [
            {'filter': 'datum.noFill == true'},
          ],
          'mark': {
            'type': 'rect',
            'fill': 'white',
          },
          'encoding': {
            'tooltip': [
              {'field': 'label', 'type': 'nominal'},
              {'field': 'prediction', 'type': 'nominal'},
              {'field': 'count', 'type': 'quantitative'},
            ]
          },
        },
    );
  }

  if (options.showTextOverlay) {
    //@ts-ignore
    spec.layer.push({
      // The per-cell count labels.
      'mark': {'type': 'text', 'baseline': 'middle'},
      'encoding': {
        'text': {
          'field': 'count',
          'type': 'nominal',
        },
      }
    });
  }

  await embed(drawArea, spec, embedOpts);
}