export function post_process_panoptic_segmentation()

in src/base/image_processors_utils.js [476:545]


/**
 * Post-process the raw output of a panoptic segmentation model into one
 * segmentation map plus per-segment metadata for each item in the batch.
 *
 * @param {Object} outputs Model output. Class logits are read from `class_queries_logits`
 *   (falling back to `logits`), expected shape [batch_size, num_queries, num_classes+1].
 *   Mask logits are read from `masks_queries_logits` (falling back to `pred_masks`),
 *   expected shape [batch_size, num_queries, height, width].
 * @param {number} [threshold=0.5] Threshold forwarded to `remove_low_and_no_objects`
 *   when filtering the queries of each batch item.
 * @param {number} [mask_threshold=0.5] Forwarded unchanged to `compute_segments`.
 * @param {number} [overlap_mask_area_threshold=0.8] Forwarded unchanged to `compute_segments`.
 * @param {Set<number>|null} [label_ids_to_fuse=null] Label ids whose instances should be fused.
 *   When `null`, a warning is logged and an empty `Set` is forwarded instead.
 * @param {Array<[number, number]>|null} [target_sizes=null] Optional `[height, width]` per batch
 *   item. When provided, its length must equal the batch size.
 * @returns {Array<{segmentation: Tensor, segments_info: Object[]}>} One entry per batch item.
 *   If no query survives filtering for an item, `segmentation` is an int32 tensor filled with -1
 *   (sized by that item's target size, or else by the last two dims of its mask tensor) and
 *   `segments_info` is empty.
 * @throws {Error} If `target_sizes` is provided and its length differs from the batch size.
 */
export function post_process_panoptic_segmentation(
    outputs,
    threshold = 0.5,
    mask_threshold = 0.5,
    overlap_mask_area_threshold = 0.8,
    label_ids_to_fuse = null,
    target_sizes = null,
) {
    // Work on a local binding instead of reassigning the caller's argument.
    let fuse_ids = label_ids_to_fuse;
    if (fuse_ids === null) {
        console.warn("`label_ids_to_fuse` unset. No instance will be fused.");
        fuse_ids = new Set();
    }

    const class_queries_logits = outputs.class_queries_logits ?? outputs.logits; // [batch_size, num_queries, num_classes+1]
    const masks_queries_logits = outputs.masks_queries_logits ?? outputs.pred_masks; // [batch_size, num_queries, height, width]

    const mask_probs = masks_queries_logits.sigmoid(); // [batch_size, num_queries, height, width]

    const [batch_size, , num_classes_with_background] = class_queries_logits.dims;
    // Remove last class (background) from the label count.
    const num_labels = num_classes_with_background - 1;

    if (target_sizes !== null && target_sizes.length !== batch_size) {
        throw new Error("Make sure that you pass in as many target sizes as the batch dimension of the logits");
    }

    const results = [];
    for (let i = 0; i < batch_size; ++i) {
        const target_size = target_sizes !== null ? target_sizes[i] : null;
        const class_logits = class_queries_logits[i];
        const mask_logits = mask_probs[i];

        const [mask_probs_item, pred_scores_item, pred_labels_item] = remove_low_and_no_objects(
            class_logits,
            mask_logits,
            threshold,
            num_labels,
        );

        if (pred_labels_item.length === 0) {
            // No mask found: every pixel of the map is -1, i.e. belongs to no segment.
            const [height, width] = target_size ?? mask_logits.dims.slice(-2);
            results.push({
                segmentation: new Tensor('int32', new Int32Array(height * width).fill(-1), [height, width]),
                segments_info: [],
            });
            continue;
        }

        // Get segmentation map and segment information of this batch item.
        const [segmentation, segments] = compute_segments(
            mask_probs_item,
            pred_scores_item,
            pred_labels_item,
            mask_threshold,
            overlap_mask_area_threshold,
            fuse_ids,
            target_size,
        );

        results.push({
            segmentation,
            segments_info: segments,
        });
    }

    return results;
}