in interactive-visualizers/src/app/app.component.ts [667:808]
/**
 * Runs the image segmentation model on `image`, computes per-label area
 * statistics and, if `index` is still the selected image, publishes the
 * results and draws the interactive overlay canvas.
 *
 * @param image Query image to segment.
 * @param index Index of the query image; compared against
 *     `this.imageSelectedIndex` so results for a stale selection are dropped.
 * @returns A promise that resolves once inference and rendering are done.
 */
async runImageSegmenter(image: HTMLImageElement, index: number):
    Promise<void> {
  let predictions: number[][] = [];
  if (this.modelFormat === 'tflite') {
    const startTs = Date.now();
    const segmentation = this.tfWebApi.run(image).getSegmentationList()[0];
    this.resultsLatency = Date.now() - startTs;
    // Reshape the flat category mask into one array of labels per mask row.
    const categoryMask = segmentation.getCategoryMask();
    const maskWidth = segmentation.getWidth();
    const maskHeight = segmentation.getHeight();
    for (let row = 0; row < maskHeight; row++) {
      predictions.push(Array.from(
          categoryMask.slice(maskWidth * row, maskWidth * (row + 1))));
    }
    // Rebuild the labelmap from the labels embedded in the result, preferring
    // the display name and falling back to the class name.
    this.labelmap = [];
    const coloredLabelList = segmentation.getColoredLabelsList();
    for (let i = 0; i < coloredLabelList.length; i++) {
      const colorLabel = coloredLabelList[i];
      this.labelmap.push(
          colorLabel.getDisplayName() || colorLabel.getClassName());
      // NOTE(review): this truthiness test also rejects colors that have any
      // channel equal to 0 - confirm whether that is intended.
      if (colorLabel.getR() && colorLabel.getG() && colorLabel.getB()) {
        COLOR_LIST[i] = [
          colorLabel.getR() / 255,
          colorLabel.getG() / 255,
          colorLabel.getB() / 255,
        ];
      }
    }
  } else {
    // Prepare inputs.
    const segmenterMetadata = this.modelMetadata.tfjs_segmenter_model_metadata;
    const imageTensor =
        this.prepareImageInput(image, segmenterMetadata.input_tensor_metadata);
    // Execute the model.
    const outputHeadMetadata = segmenterMetadata.output_head_metadata[0];
    const outputTensorName =
        outputHeadMetadata.semantic_predictions_tensor_name;
    const startTs = Date.now();
    // Every intermediate tensor is disposed in a `finally` so that a failing
    // step does not leak the tensors created before it.
    let outputTensor: tf.Tensor;
    try {
      outputTensor =
          await this.model.executeAsync(imageTensor, outputTensorName) as
          tf.Tensor;
      this.resultsLatency = Date.now() - startTs;
    } finally {
      tf.dispose(imageTensor);
    }
    let squeezedOutputTensor: tf.Tensor;
    try {
      squeezedOutputTensor = outputTensor.squeeze();
    } finally {
      tf.dispose(outputTensor);
    }
    try {
      predictions = await squeezedOutputTensor.array() as number[][];
    } finally {
      tf.dispose(squeezedOutputTensor);
    }
    // Fetch the labelmap.
    if (this.labelmap == null && outputHeadMetadata.labelmap_path != null) {
      await this.fetchLabelmap(outputHeadMetadata.labelmap_path);
    }
  }

  // Generate labelmap if not found: one generic name per label index up to
  // the largest index present in the predictions.
  if (this.labelmap == null) {
    let maxLabelIndex = 0;
    for (const predictionLine of predictions) {
      for (const prediction of predictionLine) {
        maxLabelIndex = Math.max(maxLabelIndex, prediction);
      }
    }
    this.labelmap = [];
    for (let i = 0; i <= maxLabelIndex; ++i) {
      this.labelmap.push(`Label ${i}`);
    }
  }

  // Compute label frequencies.
  const frequencies = new Array(this.labelmap.length).fill(0);
  for (const predictionLine of predictions) {
    for (const prediction of predictionLine) {
      ++frequencies[prediction];
    }
  }

  // Sort labels by decreasing area importance in the query image.
  const totalPixels =
      predictions.length > 0 ? predictions.length * predictions[0].length : 0;
  const labelList =
      frequencies
          .map((frequency, listIndex) => {
            const color = COLOR_LIST[listIndex];
            return {
              displayName: this.labelmap[listIndex],
              index: listIndex,
              frequencyPercent: Math.ceil(100 * frequency / totalPixels),
              color: `rgb(${255 * color[0]}, ${255 * color[1]}, ${
                  255 * color[2]})`,
            };
          })
          .filter(x => x.frequencyPercent > EPSILON)
          // A numeric difference gives a consistent comparator (it returns 0
          // for ties), unlike a comparator that only ever returns -1 or 1.
          .sort((a, b) => b.frequencyPercent - a.frequencyPercent);

  if (this.imageSelectedIndex !== index) {
    // Display results only for the last selected image (as the user may
    // have switched selection while inference was running).
    return;
  }

  this.segmenterPredictions = predictions;
  this.segmenterLabelList = labelList;
  this.hoveredSegmentationLabel = null;
  this.resultsKeyName = 'Type';
  this.resultsValueName = 'Percentage of image area';
  const imageHtmlElement =
      document.getElementById('query-image') as HTMLImageElement;
  this.queryImageHeight = imageHtmlElement.offsetHeight;
  this.queryImageWidth = imageHtmlElement.offsetWidth;
  // NOTE(review): in the tflite branch `predictions.length` is the number of
  // mask rows, yet it is used as the canvas width here while the hover lookup
  // below indexes `[y][x]`. Kept unchanged because `fillSegmentationCanvas`
  // may rely on the same convention - confirm behavior for non-square masks.
  const width = predictions.length;
  const height = predictions[0].length;
  const canvas =
      document.getElementById('query-canvas-overlay') as HTMLCanvasElement;
  // The canvas is stretched over the displayed image via CSS while its
  // backing store is sized from the prediction grid.
  canvas.style.height = `${this.queryImageHeight}px`;
  canvas.style.width = `${this.queryImageWidth}px`;
  canvas.width = width;
  canvas.height = height;
  const context = canvas.getContext('2d') as CanvasRenderingContext2D;
  context.fillRect(0, 0, width, height);
  this.fillSegmentationCanvas();

  // Assign the handler properties instead of calling addEventListener: this
  // replaces the handlers installed by a previous run rather than stacking a
  // new pair (closing over stale dimensions) on every inference.
  canvas.onmousemove = (event: MouseEvent) => {
    // Convert the pointer position from CSS pixels to canvas coordinates and
    // clamp it to the canvas bounds.
    const rect = canvas.getBoundingClientRect();
    const scaleX = width / rect.width;
    const scaleY = height / rect.height;
    const x = Math.min(
        width - 1,
        Math.max(0, Math.round((event.clientX - rect.left) * scaleX)));
    const y = Math.min(
        height - 1,
        Math.max(0, Math.round((event.clientY - rect.top) * scaleY)));
    const hoveredLabel = this.segmenterPredictions[y][x];
    // Redraw only when the hovered label actually changes.
    if (hoveredLabel !== this.hoveredSegmentationLabel) {
      this.hoveredSegmentationLabel = hoveredLabel;
      this.fillSegmentationCanvas();
    }
  };
  canvas.onmouseout = () => {
    this.hoveredSegmentationLabel = null;
    this.fillSegmentationCanvas();
  };
}