in object-detection/visualization_utils.py [0:0]
def draw_float_channel_on_image_array(image, channel, mask, alpha=0.9,
                                      cmap='YlGn'):
  """Draws a floating point channel on an image array.

  The channel is mapped to RGB colors with a matplotlib colormap and
  alpha-blended onto `image` wherever `mask` is non-zero. `image` is modified
  in place; nothing is returned.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3). It is
      overwritten in place with the blended result.
    channel: float32 numpy array with shape (img_height, img_width). The values
      should be in the range [0, 1], and will be mapped to colors using the
      provided colormap `cmap` argument.
    mask: a uint8 numpy array of shape (img_height, img_width) with
      1-indexed parts (0 for background). Every non-zero pixel is drawn.
    alpha: transparency value between 0 and 1 (default: 0.9). This is the
      weight of the colored channel in the blend.
    cmap: string with the colormap to use.

  Raises:
    ValueError: On incorrect data type or shape for image, channel or mask,
      or if `alpha` is outside [0, 1].
  """
  if image.dtype != np.uint8:
    raise ValueError('`image` not of type np.uint8')
  if channel.dtype != np.float32:
    raise ValueError('`channel` not of type np.float32')
  if mask.dtype != np.uint8:
    raise ValueError('`mask` not of type np.uint8')
  if image.shape[:2] != channel.shape:
    raise ValueError('The image has spatial dimensions %s but the channel has '
                     'dimensions %s' % (image.shape[:2], channel.shape))
  if image.shape[:2] != mask.shape:
    raise ValueError('The image has spatial dimensions %s but the mask has '
                     'dimensions %s' % (image.shape[:2], mask.shape))
  # The blended result is written back as RGB (see np.copyto below), so a
  # grayscale or RGBA input could never receive it. Reject it up front rather
  # than failing after all the compositing work.
  if image.ndim != 3 or image.shape[2] != 3:
    raise ValueError('`image` must have shape (img_height, img_width, 3), but '
                     'has shape %s' % (image.shape,))
  # Out-of-range alpha would overflow the uint8 blend mask built below.
  if not 0.0 <= alpha <= 1.0:
    raise ValueError('`alpha` must be in the range [0, 1], but is %s' %
                     (alpha,))

  colormap = plt.get_cmap(cmap)
  pil_image = Image.fromarray(image)
  # Applying the colormap yields RGBA floats in [0, 1]; keep only RGB.
  colored_channel = colormap(channel)[:, :, :3]
  pil_colored_channel = Image.fromarray(
      np.uint8(colored_channel * 255)).convert('RGBA')
  # Per-pixel blend weight: 255 * alpha where the mask is set, 0 elsewhere.
  pil_mask = Image.fromarray(np.uint8(255.0 * alpha * (mask > 0))).convert('L')
  # Image.composite takes the first image where the weight is 255 and the
  # second where it is 0, interpolating in between.
  pil_image = Image.composite(pil_colored_channel, pil_image, pil_mask)
  # Write the blended RGB pixels back into the caller's array.
  np.copyto(image, np.array(pil_image.convert('RGB')))