export function post_process_object_detection()

in src/base/image_processors_utils.js [90:160]


/**
 * Turn raw object-detection model outputs into per-image lists of detections.
 *
 * For every candidate box of every image, one of two selection strategies is applied:
 *  - default: the highest-scoring class is taken; the box is dropped when that class is the
 *    last one (treated as "background") or when its softmax probability is below `threshold`.
 *  - zero-shot: a sigmoid is applied to the logits and every class whose probability is
 *    strictly greater than `threshold` produces its own detection (one box may yield several).
 *
 * @param {Object} outputs Model outputs.
 * @param {Object} outputs.logits Tensor-like object with `dims` of the form
 *   `[batch_size, num_boxes, num_classes]`, indexable per image and then per box; each
 *   per-box entry must expose `data` and `sigmoid()`.
 * @param {Object} outputs.pred_boxes Tensor-like object indexable per image and then per box;
 *   each per-box entry exposes `data`, which is passed to `center_to_corners_format`.
 * @param {number} [threshold=0.5] Minimum probability for a class to be reported.
 * @param {Array|null} [target_sizes=null] Optional per-image sizes used to rescale boxes.
 *   Entry `[1]` scales the even box positions (x) and entry `[0]` scales the odd ones (y),
 *   so each entry is presumably `[height, width]` - confirm against callers.
 * @param {boolean} [is_zero_shot=false] Selects the zero-shot strategy described above.
 * @returns {{boxes: number[][], classes: number[], scores: number[]}[]} One entry per image;
 *   the three arrays are parallel (same length, same order).
 * @throws {Error} If `target_sizes` is given but its length differs from the batch size.
 */
export function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) {
    const { logits: batched_logits, pred_boxes: batched_boxes } = outputs;
    const [batch_size, num_boxes, num_classes] = batched_logits.dims;

    const has_target_sizes = target_sizes !== null;
    if (has_target_sizes && target_sizes.length !== batch_size) {
        throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits");
    }

    // The last class index is reserved for "background" in the non-zero-shot path.
    const background_class = num_classes - 1;
    const detections = [];

    for (let image = 0; image < batch_size; image += 1) {
        const image_size = has_target_sizes ? target_sizes[image] : null;
        const image_logits = batched_logits[image];
        const image_boxes = batched_boxes[image];

        const boxes = [];
        const classes = [];
        const scores = [];

        for (let candidate = 0; candidate < num_boxes; candidate += 1) {
            const candidate_logits = image_logits[candidate];

            const selected = [];
            let class_probs;

            if (!is_zero_shot) {
                // Single-label path: keep only the most probable class.
                const best_class = max(candidate_logits.data)[1];
                if (best_class === background_class) {
                    continue;
                }

                class_probs = softmax(candidate_logits.data);
                if (class_probs[best_class] < threshold) {
                    continue;
                }
                selected.push(best_class);
            } else {
                // Multi-label path: every class above the threshold is kept.
                class_probs = candidate_logits.sigmoid().data;
                for (let class_index = 0; class_index < class_probs.length; class_index += 1) {
                    if (class_probs[class_index] > threshold) {
                        selected.push(class_index);
                    }
                }
            }

            for (const class_index of selected) {
                // Each detection gets its own freshly converted [x0, y0, x1, y1] box.
                const corners = center_to_corners_format(image_boxes[candidate].data);

                // Even positions (x) are scaled by image_size[1], odd positions (y) by image_size[0].
                const final_box = image_size === null
                    ? corners
                    : corners.map((coord, position) => coord * image_size[(position + 1) % 2]);

                boxes.push(final_box);
                classes.push(class_index);
                scores.push(class_probs[class_index]);
            }
        }

        detections.push({ boxes, classes, scores });
    }

    return detections;
}