in src/base/image_processors_utils.js [323:422]
/**
 * Builds a panoptic segmentation map from per-query mask probabilities.
 *
 * Every pixel is assigned to the query whose score-weighted mask value is largest (an argmax over
 * queries; ties go to the lowest query index). Each query that then passes `check_segment_validity`
 * is written into the segmentation map under a segment id starting at 1. Pixels that belong to no
 * accepted segment keep the value 0.
 *
 * NOTE: `mask_probs` is modified in place. When `target_size` is given, its entries are replaced
 * by resized tensors. In all cases, each mask's data is multiplied by its score.
 *
 * @param {Tensor[]} mask_probs One mask tensor per query (presumably `[height, width]`; TODO confirm).
 * @param {number[]} pred_scores Score of each query, in the same order as `mask_probs`.
 * @param {number[]} pred_labels Predicted class label of each query.
 * @param {number} mask_threshold Forwarded to `check_segment_validity`.
 * @param {number} overlap_mask_area_threshold Forwarded to `check_segment_validity`.
 * @param {Set<number>|number[]|null} [label_ids_to_fuse=null] Labels whose instances are merged
 *   into a single segment id (the id of the first accepted instance is reused).
 * @param {number[]|null} [target_size=null] Output `[height, width]`. When given, masks are
 *   bilinearly resized to it; otherwise the size is taken from `mask_probs[0].dims`.
 * @returns {[Tensor, {id: number, label_id: number, was_fused: boolean, score: number}[]]}
 *   The int32 segmentation map and one info object per accepted query.
 * @throws {Error} If `mask_probs` is empty and no `target_size` is given (the output size is unknown).
 */
function compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold,
    overlap_mask_area_threshold,
    label_ids_to_fuse = null,
    target_size = null,
) {
    if (target_size === null && mask_probs.length === 0) {
        throw new Error('compute_segments: `target_size` is required when `mask_probs` is empty.');
    }
    const [height, width] = target_size ?? mask_probs[0].dims;
    const segmentation = new Tensor(
        'int32',
        new Int32Array(height * width),
        [height, width],
    );
    const segments = [];

    // Nothing was detected: return the all-zero map instead of reading `mask_probs[0]` below.
    if (mask_probs.length === 0) {
        return [segmentation, segments];
    }

    // 1. Resize the masks to the requested output size (this replaces the caller's array entries).
    if (target_size !== null) {
        for (let i = 0; i < mask_probs.length; ++i) {
            mask_probs[i] = interpolate(mask_probs[i], target_size, 'bilinear', false);
        }
    }

    // 2. Weigh each mask by its prediction score (in place) and track the per-pixel argmax.
    const num_pixels = mask_probs[0].data.length;
    const mask_labels = new Int32Array(num_pixels);
    // Start at -Infinity (not 0) so the argmax is correct even when every weighted value is <= 0.
    const bestScores = new Float32Array(num_pixels).fill(-Infinity);
    for (let i = 0; i < mask_probs.length; ++i) {
        const score = pred_scores[i];
        const mask_probs_i_data = mask_probs[i].data;
        for (let j = 0; j < mask_probs_i_data.length; ++j) {
            mask_probs_i_data[j] *= score;
            if (mask_probs_i_data[j] > bestScores[j]) {
                mask_labels[j] = i;
                bestScores[j] = mask_probs_i_data[j];
            }
        }
    }

    // 3. Turn every valid query into a segment, fusing instances of the requested labels.
    const labels_to_fuse = new Set(label_ids_to_fuse ?? []);
    // Maps a fused label to the segment id of its first accepted instance.
    const fused_segment_ids = new Map();
    // Kept separate from reused (fused) ids so a newly allocated id never repeats an existing one.
    let last_segment_id = 0;
    const segmentation_data = segmentation.data;
    for (let k = 0; k < pred_labels.length; ++k) {
        const pred_class = pred_labels[k];
        const should_fuse = labels_to_fuse.has(pred_class);

        // Check if the mask exists and is large enough to be a segment
        const [mask_exists, mask_k] = check_segment_validity(
            mask_labels,
            mask_probs,
            k,
            mask_threshold,
            overlap_mask_area_threshold,
        );
        if (!mask_exists) {
            continue;
        }

        let segment_id = fused_segment_ids.get(pred_class);
        if (segment_id === undefined) {
            segment_id = ++last_segment_id;
            if (should_fuse) {
                fused_segment_ids.set(pred_class, segment_id);
            }
        }

        // Add the current object segment to the final segmentation map
        for (const index of mask_k) {
            segmentation_data[index] = segment_id;
        }
        segments.push({
            id: segment_id,
            label_id: pred_class,
            was_fused: should_fuse,
            score: pred_scores[k],
        });
    }
    return [segmentation, segments];
}