def sample_pixel_from_image()

in fairnr/data/data_utils.py [0:0]


def sample_pixel_from_image(
    num_pixel, num_sample,
    mask=None, ratio=1.0,
    use_bbox=False,
    center_ratio=1.0,
    width=512,
    patch_size=1):
    """Sample flat pixel indices from an image, optionally biased to a mask.

    The image is treated as ``num_pixel // width`` rows of ``width`` pixels,
    indexed row-major (``index = row * width + col``).

    Args:
        num_pixel: total number of pixels in the image.
        num_sample: number of indices to draw (with replacement).
        mask: optional per-pixel mask with ``num_pixel`` elements (flat or
            ``(height, width)``); entries ``> 0`` count as foreground.
        ratio: probability mass assigned to foreground pixels (uniformly
            within the foreground); the remaining ``1 - ratio`` is spread
            uniformly over background pixels. ``ratio <= 0`` means uniform
            sampling over the whole image.
        use_bbox: if True, replace the foreground by its axis-aligned
            bounding box before sampling.
        center_ratio: if < 1, restrict the foreground to a centered crop
            whose sides are roughly this fraction of the image sides
            (intersected with ``mask`` when one is given).
        width: image width in pixels.
        patch_size: if > 1, sample ``patch_size x patch_size`` patches on a
            grid with stride ``patch_size`` and return every pixel index of
            each sampled patch.

    Returns:
        1-D integer numpy array of ``num_sample`` flat pixel indices.

    Falls back to uniform sampling when there is no usable mask (no mask,
    empty foreground, or everything foreground) or when the resulting
    distribution is invalid (e.g. ``ratio > 1``).
    """
    if mask is not None:
        # Accept (H, W) masks as well as flat ones; everything below works on
        # the flat layout (a 2-D mask used to yield 2-D probs, which made
        # np.random.choice fail and silently fall back to uniform sampling).
        mask = np.asarray(mask).reshape(-1)

    if patch_size > 1:
        patch_area = patch_size * patch_size
        height = num_pixel // width
        # Patches must tile the image exactly, otherwise the coarse-grid
        # indices below map to misaligned / out-of-range pixels.
        assert (num_pixel % patch_area == 0) \
            and (num_sample % patch_area == 0) \
            and (height * width == num_pixel) \
            and (height % patch_size == 0) \
            and (width % patch_size == 0), "size must match"
        _num_pixel = num_pixel // patch_area
        _num_sample = num_sample // patch_area
        _width = width // patch_size

        # A coarse cell is foreground if any pixel inside its patch is.
        _mask = None if mask is None else \
            mask.reshape(
                height // patch_size, patch_size, _width, patch_size
            ).any(1).any(-1).reshape(-1)

        # Sample patch cells on the coarse grid. width/center_ratio must be
        # passed by keyword: positionally, _width would land in center_ratio.
        _out = sample_pixel_from_image(
            _num_pixel, _num_sample, _mask, ratio, use_bbox,
            center_ratio=center_ratio, width=_width)
        _x, _y = _out % _width, _out // _width
        x, y = _x * patch_size, _y * patch_size
        offsets = np.arange(patch_size)
        # Axis 1 carries the column offset and axis 2 the row offset, so after
        # the C-order flatten each patch is listed column by column.
        # NOTE(review): ordering kept as-is; confirm callers expect it.
        x = x[:, None, None] + offsets[None, :, None]
        y = y[:, None, None] + offsets[None, None, :]
        out = x + y * width
        return out.reshape(-1)

    if center_ratio < 1.0:
        r = (1 - center_ratio) / 2.0
        H, W = num_pixel // width, width
        mask0 = np.zeros((H, W))
        mask0[int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1
        mask0 = mask0.reshape(-1)
        mask = mask0 if mask is None else mask * mask0

    if mask is not None:
        mask = (mask > 0.0).astype('float32')

    # Without both a non-empty foreground and a non-empty background the
    # biased distribution is undefined, so sample uniformly.
    if (mask is None) or \
            (ratio <= 0.0) or \
            (mask.sum() == 0) or \
            ((1 - mask).sum() == 0):
        return np.random.choice(num_pixel, num_sample)

    if use_bbox:
        mask = mask.reshape(-1, width)
        rows, cols = np.where(mask == 1)
        mask = np.zeros_like(mask)
        mask[rows.min(): rows.max() + 1, cols.min(): cols.max() + 1] = 1.0
        mask = mask.reshape(-1)
        if (1 - mask).sum() == 0:
            # The bounding box covers the whole image: no background left,
            # which would otherwise divide by zero below.
            return np.random.choice(num_pixel, num_sample)

    try:
        num_fg = mask.sum()
        num_bg = num_pixel - num_fg
        probs = mask * ratio / num_fg + (1 - mask) / num_bg * (1 - ratio)
        return np.random.choice(num_pixel, num_sample, True, p=probs)
    except (ValueError, FloatingPointError):
        # Invalid distribution (e.g. ratio > 1 gives negative background
        # probabilities, or mask size != num_pixel): best-effort uniform.
        return np.random.choice(num_pixel, num_sample)