def n_experimental_roi_selection_strategy()

in 04_detect_segment/utils_box.py [0:0]


def n_experimental_roi_selection_strategy(tile, rois, rois_n, grid_n, n, cell_grow):
    """Pick ``n`` ROIs per grid cell, falling back to ROIs from an enlarged cell.

    For every cell of a ``grid_n x grid_n`` grid, ``n_largest_rois_in_cell_relative``
    is queried twice with ``comparator="closest_to_center"``: once with
    ``expand=1.0`` ("normal" ROIs r1, r2) and once with ``expand=cell_grow``
    ("peripheral" ROIs p1, p2). A peripheral ROI whose third component equals
    that of the normal ROI in the same slot is zeroed out as a duplicate. A fixed
    decision table (below) then picks the two output ROIs out of
    (r1, r2, p1, p2) according to which of them are present, where "present"
    means third component > 0. Normal ROIs are preferred; peripheral ROIs only
    fill slots the normal ROIs leave empty, and a lone available ROI is
    duplicated into both output slots.

    Args:
        tile, rois, rois_n: forwarded unchanged to
            ``n_largest_rois_in_cell_relative`` (semantics defined there).
        grid_n: number of grid cells per side; the result is reshaped to
            ``[grid_n, grid_n, n, 3]``.
        n: ROIs per cell. Only ``n == 2`` (CELL_B=2) is implemented.
        cell_grow: passed as ``expand`` for the peripheral query.

    Returns:
        Tensor of shape ``[grid_n, grid_n, n, 3]``.

    Raises:
        NotImplementedError: if ``n != 2``.
    """
    if n != 2:
        raise NotImplementedError("only implemented for CELL_B=2, got n=%r" % (n,))

    # Output slot choice for every presence pattern of (r1, r2, p1, p2), indexed by
    # r1*8 + r2*4 + p1*2 + p2 (1 = present). Slot ids: 0=r1, 1=r2, 2=p1, 3=p2.
    selection_table = [
        [0, 0],  # 0000: nothing present -> r1 twice (r1 is itself empty)
        [3, 3],  # 0001: p2 only (not expected to occur)
        [2, 2],  # 0010: p1 only -> duplicated
        [2, 3],  # 0011: p1, p2
        [1, 1],  # 0100: (not expected to occur)
        [1, 3],  # 0101: (not expected to occur)
        [1, 2],  # 0110: (not expected to occur)
        [1, 2],  # 0111: (not expected to occur)
        [0, 0],  # 1000: r1 only -> duplicated
        [0, 3],  # 1001: r1 + p2 (p1 absent or dropped as a duplicate of r1)
        [0, 2],  # 1010: r1 + p1
        [0, 2],  # 1011: r1 + p1
        [0, 1],  # 1100: r1, r2
        [0, 1],  # 1101: r1, r2
        [0, 1],  # 1110: r1, r2
        [0, 1],  # 1111: r1, r2
    ]

    normal_rois = n_largest_rois_in_cell_relative(tile, rois, rois_n, grid_n, n, comparator="closest_to_center", expand=1.0)
    periph_rois = n_largest_rois_in_cell_relative(tile, rois, rois_n, grid_n, n, comparator="closest_to_center", expand=1.0*cell_grow)

    cells = grid_n * grid_n
    rsnormal_rois = tf.reshape(normal_rois, [cells, n, 3])
    rsperiph_rois = tf.reshape(periph_rois, [cells, n, 3])
    rw = rsnormal_rois[:, :, 2]
    px, py, pw = tf.unstack(rsperiph_rois, axis=-1)
    # Keep in the periphery only ROIs that are NOT already among the normal ROIs.
    # NOTE(review): the match compares only the third component, slot by slot, so a
    # distinct peripheral ROI with an identical third component is dropped too.
    pw = tf.where(tf.equal(rw, pw), tf.zeros_like(pw), pw)
    rsperiph_rois = tf.stack([px, py, pw], axis=-1)
    # [cells, 4, 3]: per cell, rows 0..3 are r1, r2, p1, p2.
    rscombined_rois = tf.concat([rsnormal_rois, rsperiph_rois], axis=1)

    # Vectorized decision table: encode presence bits as an index into selection_table.
    present = tf.cast(tf.greater(rscombined_rois[:, :, 2], 0), tf.int32)  # [cells, 4]
    pattern = tf.reduce_sum(present * tf.constant([8, 4, 2, 1], dtype=tf.int32), axis=1)  # [cells]
    selection = tf.gather(tf.constant(selection_table, dtype=tf.int32), pattern)  # [cells, n]
    cell_ids = tf.tile(tf.expand_dims(tf.range(cells), 1), [1, n])  # [cells, n]
    combined = tf.gather_nd(rscombined_rois, tf.stack([cell_ids, selection], axis=-1))  # [cells, n, 3]
    return tf.reshape(combined, [grid_n, grid_n, n, 3])