in 04_detect_segment/utils_box.py [0:0]
def n_experimental_roi_selection_strategy(tile, rois, rois_n, grid_n, n, cell_grow):
    """Pick ``n`` ROIs per grid cell by combining in-cell and peripheral candidates.

    Two candidate sets are obtained from ``n_largest_rois_in_cell_relative``
    (comparator ``"closest_to_center"``):

    * "normal" ROIs, searched with ``expand=1.0``;
    * "peripheral" ROIs, searched with ``expand=1.0 * cell_grow``.

    A peripheral ROI whose third component equals that of the normal ROI in the
    same slot is treated as a duplicate and emptied. A fixed decision table
    (see ``selection_table`` below) then chooses, per cell, which two of the four
    candidates ``[r1, r2, p1, p2]`` are kept.

    Only ``n == 2`` is implemented, because the decision table is written for
    exactly two normal plus two peripheral candidates.

    Args:
        tile, rois, rois_n: forwarded unchanged to
            ``n_largest_rois_in_cell_relative``.
        grid_n: number of grid cells per side; the result has
            ``grid_n * grid_n`` cells.
        n: number of ROIs kept per cell; must be 2.
        cell_grow: multiplier applied to ``expand`` for the peripheral search.

    Returns:
        Tensor of shape ``[grid_n, grid_n, n, 3]``. Each ROI is a triple whose
        last component (``w``) is > 0 for a non-empty slot. Assumes the helper's
        output can be reshaped to ``[grid_n * grid_n, n, 3]`` — TODO confirm
        against ``n_largest_rois_in_cell_relative``.

    Raises:
        AssertionError: if ``n != 2`` (not checked when Python runs with ``-O``).
    """
    assert n == 2, "only implemented for n == 2 (CELL_B=2)"
    normal_rois = n_largest_rois_in_cell_relative(
        tile, rois, rois_n, grid_n, n,
        comparator="closest_to_center", expand=1.0)
    periph_rois = n_largest_rois_in_cell_relative(
        tile, rois, rois_n, grid_n, n,
        comparator="closest_to_center", expand=1.0 * cell_grow)

    # Decision table. Within a cell the candidates are ordered
    # [r1, r2, p1, p2]: two normal ROIs, then two peripheral ROIs.
    # The row index is a 4-bit code of which candidates are non-empty (w > 0):
    # r1 -> 8, r2 -> 4, p1 -> 2, p2 -> 1. Each row holds the candidate
    # indices of the two ROIs kept for the cell.
    #
    #  code  r1 r2 p1 p2   kept
    #    0   0  0  0  0    r1 r1   (all empty)
    #    1   0  0  0  t    p2 p2   (author: cannot happen)
    #    2   0  0  z  0    p1 p1
    #    3   0  0  z  t    p1 p2
    #    4   0  y  0  0    r2 r2   (author: cannot happen)
    #    5   0  y  0  t    r2 p2   (author: cannot happen)
    #    6   0  y  z  0    r2 p1   (author: cannot happen)
    #    7   0  y  z  t    r2 p1   (author: cannot happen)
    #    8   x  0  0  0    r1 r1
    #    9   x  0  0  t    r1 p2   (can happen, e.g. p1 emptied as a duplicate)
    #   10   x  0  z  0    r1 p1
    #   11   x  0  z  t    r1 p1
    #   12   x  y  0  0    r1 r2
    #   13   x  y  0  t    r1 r2   (author: cannot happen)
    #   14   x  y  z  0    r1 r2
    #   15   x  y  z  t    r1 r2
    # Every 4-bit code has a row, so no fallback value is needed.
    selection_table = tf.constant(
        [[0, 0], [3, 3], [2, 2], [2, 3],
         [1, 1], [1, 3], [1, 2], [1, 2],
         [0, 0], [0, 3], [0, 2], [0, 2],
         [0, 1], [0, 1], [0, 1], [0, 1]],
        dtype=tf.int32)
    bit_weights = tf.constant([8, 4, 2, 1], dtype=tf.int32)

    def roi_select(cell_rois):
        """Return the two ROIs kept for one cell.

        ``cell_rois`` is the ``[4, 3]`` slice ``[r1, r2, p1, p2]`` of one cell;
        the result is ``[2, 3]`` with the same dtype as the input.
        """
        non_empty = tf.cast(tf.greater(cell_rois[:, 2], 0), tf.int32)  # [4]
        code = tf.reduce_sum(non_empty * bit_weights)  # scalar in [0, 15]
        return tf.gather(cell_rois, tf.gather(selection_table, code))  # [2, 3]

    rsnormal_rois = tf.reshape(normal_rois, [grid_n * grid_n, n, 3])
    _, _, rw = tf.unstack(rsnormal_rois, axis=-1)
    rsperiph_rois = tf.reshape(periph_rois, [grid_n * grid_n, n, 3])
    px, py, pw = tf.unstack(rsperiph_rois, axis=-1)
    # Keep in the periphery set only ROIs that are NOT already in the normal
    # set, i.e. ROIs further out than one cell radius.
    # NOTE(review): duplicates are detected by comparing w only, slot by slot;
    # two distinct ROIs with equal w in the same slot are also dropped. Confirm
    # this is intended.
    pw = tf.where(tf.equal(rw, pw), tf.zeros_like(pw), pw)
    rsperiph_rois = tf.stack([px, py, pw], axis=2)
    # [grid_n*grid_n, 2n, 3] -> per-cell selection -> [grid_n*grid_n, n, 3]
    rscombined_rois = tf.concat([rsnormal_rois, rsperiph_rois], axis=1)
    rscombined_rois = tf.map_fn(roi_select, rscombined_rois)
    return tf.reshape(rscombined_rois, [grid_n, grid_n, n, 3])