fn process_crop()

in candle-transformers/src/models/segment_anything/sam.rs [220:320]


    /// Generates candidate masks and bounding boxes for a single crop of `img`.
    ///
    /// The region described by `cb` is sliced out of `img`, preprocessed and run through the
    /// image encoder once. Every entry of `point_grids` is then scaled by the crop width and
    /// height and used as a single-point prompt; prompts are decoded in batches with
    /// multimask output enabled. A predicted mask is kept only when:
    ///
    /// * its predicted IoU is at least `PRED_IOU_THRESH`,
    /// * its stability score (pixel count above `MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET`
    ///   divided by pixel count above `MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET`) is a
    ///   number and at least `STABILITY_SCORE_THRESHOLD`,
    /// * `min_max_indexes` yields an extent along both axes of the binarized mask.
    ///
    /// Surviving boxes are deduplicated with non-maximum suppression using `CROP_NMS_THRESH`.
    ///
    /// # Arguments
    ///
    /// * `img` - image tensor, indexed here as `(.., y, x)`.
    /// * `cb` - crop box; `y0..y1` and `x0..x1` select the region to process.
    /// * `point_grids` - `(x, y)` prompt points; they are multiplied by the crop size, so they
    ///   are presumably normalized to `[0, 1]` (confirm against the caller).
    ///
    /// # Returns
    ///
    /// One `Bbox` per retained mask. `confidence` is the predicted IoU and `data` is the
    /// binarized (`U32`) low-resolution mask. Box coordinates are indices into that
    /// low-resolution mask; they are not yet mapped back to the original image frame.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the tensor operations or the model forward passes.
    fn process_crop(
        &self,
        img: &Tensor,
        cb: CropBox,
        point_grids: &[(f64, f64)],
    ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
        /// Number of point prompts decoded per forward pass.
        const POINTS_PER_BATCH: usize = 64;

        // Crop the image and calculate embeddings.
        let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
        let img = self.preprocess(&img)?.unsqueeze(0)?;
        let img_embeddings = self.image_encoder.forward(&img)?;

        let crop_w = cb.x1 - cb.x0;
        let crop_h = cb.y1 - cb.y0;

        // Generate masks for this crop.
        let image_pe = self.prompt_encoder.get_dense_pe()?;
        // Scale the grid points by the crop size to get the prompt coordinates.
        let points = point_grids
            .iter()
            .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
            .collect::<Vec<_>>();

        let mut bboxes = Vec::new();
        for points in points.chunks(POINTS_PER_BATCH) {
            // Run the model on this batch: one prompt per point, each labelled with 1.
            let points_len = points.len();
            let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
            let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
            let (sparse_prompt_embeddings, dense_prompt_embeddings) =
                self.prompt_encoder
                    .forward(Some((&in_points, &in_labels)), None, None)?;

            let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
                &img_embeddings,
                &image_pe,
                &sparse_prompt_embeddings,
                &dense_prompt_embeddings,
                /* multimask_output */ true,
            )?;
            // Merge the two leading dimensions so that masks and IoU predictions share a
            // single flat index.
            let low_res_mask = low_res_mask.flatten(0, 1)?;
            let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
            let dev = low_res_mask.device();

            for (i, iou) in iou_predictions.iter().enumerate() {
                // Filter by predicted IoU.
                if *iou < PRED_IOU_THRESH {
                    continue;
                }
                let low_res_mask = low_res_mask.get(i)?;

                // Calculate stability score: how much the mask area changes when the
                // binarization threshold is moved up and down by the offset.
                let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
                    .broadcast_as(low_res_mask.shape())?;
                let intersections = low_res_mask
                    .ge(&bound)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_vec0::<f32>()?;
                let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
                    .broadcast_as(low_res_mask.shape())?;
                let unions = low_res_mask
                    .ge(&bound)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_vec0::<f32>()?;
                let stability_score = intersections / unions;
                // When no pixel reaches the lower bound the ratio is 0/0 = NaN, and every
                // ordered comparison with NaN is false, so NaN has to be rejected explicitly.
                if stability_score.is_nan() || stability_score < STABILITY_SCORE_THRESHOLD {
                    continue;
                }

                // Threshold masks and calculate boxes. The same threshold as for the
                // stability score is used so that both steps agree on what a mask pixel is.
                let low_res_mask = low_res_mask
                    .ge(&Tensor::new(MODEL_MASK_THRESHOLD, dev)?
                        .broadcast_as(low_res_mask.shape())?)?
                    .to_dtype(DType::U32)?;
                // Collapsing dimension 0 gives one count per index of dimension 1 (used as
                // x), and vice versa for y.
                let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
                let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
                let min_max_x = min_max_indexes(&low_res_mask_per_x);
                let min_max_y = min_max_indexes(&low_res_mask_per_y);
                // Only emit a box when an extent was found along both axes.
                if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
                    let bbox = crate::object_detection::Bbox {
                        xmin: x0 as f32,
                        ymin: y0 as f32,
                        xmax: x1 as f32,
                        ymax: y1 as f32,
                        confidence: *iou,
                        data: low_res_mask,
                    };
                    bboxes.push(bbox);
                }
                // TODO:
                // Filter boxes that touch crop boundaries
                // Compress to RLE.
            }
        }

        // The NMS helper takes a list of box lists, so wrap ours as the only entry.
        let mut bboxes = vec![bboxes];
        // Remove duplicates within this crop.
        crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);

        // TODO: Return to the original image frame.
        // Unwrap the single list again; index 0 always exists as it was inserted above.
        Ok(bboxes.remove(0))
    }