def zero_where()

in 04_detect_segment/utils_box.py [0:0]


def zero_where(rois, mask):
    """Zero out the coordinate vectors of `rois` wherever `mask` is True.

    Args:
        rois: tensor whose last axis holds the coordinates of one box
            (presumably 4 values per box -- TODO confirm against callers).
            Its leading dimensions must match the shape of `mask`. The size
            of the last axis does not need to be statically known.
        mask: boolean tensor with the shape of `rois` minus its last axis.

    Returns:
        A tensor with the same shape and dtype as `rois`. Every coordinate
        vector selected by `mask` is replaced by zeros; all others are copied
        from `rois` unchanged.
    """
    # Size of the coordinate axis, read from the dynamic shape so that rois
    # with an unknown static last dimension are supported too.
    coordinate_size = tf.shape(rois)[-1:]  # 1-element int32 tensor, e.g. [4]
    # Tile multiples: 1 for every axis of mask, coordinate_size for the new
    # trailing axis, e.g. [1, 1, ..., 4].
    multiples = tf.concat([tf.ones_like(tf.shape(mask)), coordinate_size], axis=0)
    # The v1 tf.where does not broadcast its condition against the branches,
    # so replicate each boolean along the last axis to match rois' shape.
    full_mask = tf.tile(tf.expand_dims(mask, axis=-1), multiples)
    return tf.where(full_mask, tf.zeros_like(rois), rois)