def batch_intersection_over_union()

in 04_detect_segment/utils_box.py [0:0]


    def batch_intersection_over_union(cls, rects1, rects2, tile_size):
        """Computes the intersection over union of two sets of rectangles.

        The actual computation is:
            intersection_area(union(rects1), union(rects2)) / union_area(rects1, rects2)
        This works on batches of rectangles but instantiates a bitmap of size tile_size to compute
        the intersections and is therefore both slow and memory-intensive. Use sparingly.

        Args:
            rects1: detected rectangles, shape [batch, n1, 4] with coordinates x1, y1, x2, y2
            rects2: ground truth rectangles, shape [batch, n2, 4] with coordinates x1, y1, x2, y2
                The size of the rectangles is [x2-x1, y2-y1]. n1 and n2 are read independently
                and need not be equal.
            tile_size: size of the images where the rectangles apply (also size of internal bitmaps)

        Returns:
            An array of shape [batch]. Use batch_mean() to correctly average it.
            Returns 1 in cases in the batch where the combined union of rects1 and rects2 covers
            no pixel, e.g. when both contain no rectangles (correctly detected nothing when there
            was nothing to detect) or when all rectangles were cropped away.
        """
        batch = tf.shape(rects1)[0]
        n1 = tf.shape(rects1)[1]  # number of rectangles per batch element in rects1
        n2 = tf.shape(rects2)[1]  # number of rectangles per batch element in rects2
        linmap1 = cls.__iou_gen_linmap(batch, n1, tile_size)
        linmap2 = cls.__iou_gen_linmap(batch, n2, tile_size)
        map1 = cls.__iou_gen_rectmap(linmap1, rects1, tile_size)  # shape [batch, n1, tile_size, tile_size]
        map2 = cls.__iou_gen_rectmap(linmap2, rects2, tile_size)  # shape [batch, n2, tile_size, tile_size]
        union1 = tf.reduce_any(map1, axis=1)  # shape [batch, tile_size, tile_size]
        union2 = tf.reduce_any(map2, axis=1)  # shape [batch, tile_size, tile_size]
        # OR-ing the two per-set unions is exactly reduce_any over concat([map1, map2], axis=1),
        # but avoids materialising an extra [batch, n1+n2, tile_size, tile_size] bitmap.
        union_all = tf.logical_or(union1, union2)  # shape [batch, tile_size, tile_size]
        intersect = tf.logical_and(union1, union2)  # shape [batch, tile_size, tile_size]
        union_area = tf.reduce_sum(tf.cast(union_all, tf.float32), axis=[1, 2])
        inter_area = tf.reduce_sum(tf.cast(intersect, tf.float32), axis=[1, 2])
        # The union can be empty (no rectangles at all, or rectangles cropped to nothing).
        # In that case both areas are replaced by 1 so the IoU is defined as 1 rather than 0/0.
        empty_union = tf.equal(union_area, 0.0)
        safe_union_area = tf.where(empty_union, tf.ones_like(union_area), union_area)
        safe_inter_area = tf.where(empty_union, tf.ones_like(inter_area), inter_area)
        iou = safe_inter_area / safe_union_area  # 1 where the union is empty
        return iou