in tfjs-vis/src/render/confusion_matrix.ts [69:247]
/**
 * Renders a confusion matrix into `container` as a vega-lite heatmap.
 *
 * Row `i` of `data.values` is drawn on the y axis ("label") and column `j` on
 * the x axis ("prediction"); each cell becomes one rect whose color encodes
 * its count.
 *
 * @param container Where the chart is drawn (resolved via `getDrawArea`).
 * @param data The matrix values plus optional tick labels. When
 *     `data.tickLabels` is missing or empty, labels of the form `Class i` are
 *     generated. `data` itself is never modified.
 * @param opts Rendering options; merged over the module's `defaultOpts`.
 * @returns A promise that resolves once the chart has been embedded.
 */
export async function confusionMatrix(
    container: Drawable, data: ConfusionMatrixData,
    opts: ConfusionMatrixOptions = {}): Promise<void> {
  const options = Object.assign({}, defaultOpts, opts);
  const drawArea = getDrawArea(container);

  // Object.assign copies an explicit `undefined` over the default, so
  // normalize once and use this single flag everywhere. This keeps the data
  // preparation (which marks diagonal cells `noFill`) and the layer
  // construction (which draws those cells as white rects) in agreement;
  // otherwise diagonal cells could be filtered out of every layer.
  const shadeDiagonal = Boolean(options.shadeDiagonal);

  // Copy the labels so that generating default labels below never pushes
  // into the caller's `data.tickLabels` (an empty array is truthy, so
  // `data.tickLabels || []` would otherwise alias it).
  const tickLabels = (data.tickLabels || []).slice();
  const generateLabels = tickLabels.length === 0;

  // Format data for the vega spec: an array of objects, one for each cell in
  // the matrix.
  const values: MatrixEntry[] = [];
  const inputArray = data.values;
  let nonDiagonalIsAllZeroes = true;
  for (let i = 0; i < inputArray.length; i++) {
    const label = generateLabels ? `Class ${i}` : tickLabels[i];
    if (generateLabels) {
      tickLabels.push(label);
    }
    for (let j = 0; j < inputArray[i].length; j++) {
      const prediction = generateLabels ? `Class ${j}` : tickLabels[j];
      const count = inputArray[i][j];
      if (i === j && !shadeDiagonal) {
        // Unshaded diagonal cell: excluded from the colored layer by its
        // `noFill` filter and drawn by the white-rect layer instead.
        values.push({
          label,
          prediction,
          count,
          noFill: true,
        });
      } else {
        values.push({
          label,
          prediction,
          count,
          scaleCount: count,
        });
        // When not shading the diagonal we want to check if there is a
        // non-zero value. If all values are zero we will not color them, as
        // the scale will be invalid.
        if (count !== 0) {
          nonDiagonalIsAllZeroes = false;
        }
      }
    }
  }

  if (!shadeDiagonal && nonDiagonalIsAllZeroes) {
    // The caller asked us not to shade the diagonal, but all the other values
    // are zero. We can either shade nothing or shade only the diagonal; we
    // shade the diagonal as that is likely more helpful, even though it is
    // not what was requested.
    for (const val of values) {
      if (val.noFill === true) {
        val.noFill = false;
        val.scaleCount = val.count;
      }
    }
  }

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  // A named (string) color map is handed to vega as a color scheme; any other
  // value is used as the scale range as-is.
  const colorMap = typeof options.colorMap === 'string' ?
      {scheme: options.colorMap} :
      options.colorMap;

  //@ts-ignore
  const spec: VisualizationSpec = {
    'width': options.width || getDefaultWidth(drawArea),
    'height': options.height || getDefaultHeight(drawArea),
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      }
    },
    //@ts-ignore
    'data': {'values': values},
    'encoding': {
      'x': {
        'field': 'prediction',
        'type': 'ordinal',
        'title': options.xLabel || 'prediction',
        // Maintain the sort order of the axis if labels are passed in.
        'scale': {'domain': tickLabels},
      },
      'y': {
        'field': 'label',
        'type': 'ordinal',
        'title': options.yLabel || 'label',
        // Maintain the sort order of the axis if labels are passed in.
        'scale': {'domain': tickLabels},
      },
    },
    'layer': [
      {
        // The matrix: every cell that is colored by its count.
        'transform': [
          {'filter': 'datum.noFill != true'},
        ],
        'mark': {
          'type': 'rect',
        },
        'encoding': {
          'color': {
            'field': 'scaleCount',
            'type': 'quantitative',
            //@ts-ignore
            'scale': {'range': colorMap},
          },
          'tooltip': [
            {'field': 'label', 'type': 'nominal'},
            {'field': 'prediction', 'type': 'nominal'},
            {'field': 'count', 'type': 'quantitative'},
          ]
        },
      },
    ]
  };

  if (!shadeDiagonal) {
    //@ts-ignore
    spec.layer.push(
        {
          // Render unfilled (white) rects for the diagonal.
          'transform': [
            {'filter': 'datum.noFill == true'},
          ],
          'mark': {
            'type': 'rect',
            'fill': 'white',
          },
          'encoding': {
            'tooltip': [
              {'field': 'label', 'type': 'nominal'},
              {'field': 'prediction', 'type': 'nominal'},
              {'field': 'count', 'type': 'quantitative'},
            ]
          },
        },
    );
  }

  if (options.showTextOverlay) {
    //@ts-ignore
    spec.layer.push({
      // The text labels: the raw count drawn in the middle of each cell.
      'mark': {'type': 'text', 'baseline': 'middle'},
      'encoding': {
        'text': {
          'field': 'count',
          'type': 'nominal',
        },
      }
    });
  }

  await embed(drawArea, spec, embedOpts);
}