def draw_float_channel_on_image_array()

in 3_predict/visualization_utils.py [0:0]


def draw_float_channel_on_image_array(image, channel, mask, alpha=0.9,
                                      cmap='YlGn'):
  """Draws a floating point channel on an image array.

  `channel` is mapped to RGB colors with the colormap `cmap` and blended onto
  `image` with opacity `alpha` wherever `mask` is nonzero. Pixels where `mask`
  is 0 are left untouched. The result is written back into `image` in place;
  nothing is returned.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3). Modified
      in place.
    channel: float32 numpy array with shape (img_height, img_width). The values
      should be in the range [0, 1], and will be mapped to colors using the
      provided colormap `cmap` argument.
    mask: a uint8 numpy array of shape (img_height, img_width) with
      1-indexed parts (0 for background). Every nonzero pixel receives the
      overlay; individual part indices are not distinguished.
    alpha: opacity of the overlay, between 0 and 1 inclusive (default: 0.9).
      0 leaves the image unchanged; 1 fully replaces masked pixels with the
      colormapped channel.
    cmap: string with the name of the colormap to use (looked up with
      `plt.get_cmap`).

  Raises:
    ValueError: If `image`, `channel` or `mask` has the wrong dtype, if
      `image` is not a 3-channel array, if the spatial dimensions of `channel`
      or `mask` do not match those of `image`, or if `alpha` is outside
      [0, 1].
  """
  if image.dtype != np.uint8:
    raise ValueError('`image` not of type np.uint8')
  if channel.dtype != np.float32:
    raise ValueError('`channel` not of type np.float32')
  if mask.dtype != np.uint8:
    raise ValueError('`mask` not of type np.uint8')
  # The blended result is converted to RGB and copied back into `image`, so
  # anything other than an (H, W, 3) array would only fail at the very end.
  if image.ndim != 3 or image.shape[2] != 3:
    raise ValueError('`image` must have shape (img_height, img_width, 3) but '
                     'has shape %s' % (image.shape,))
  if image.shape[:2] != channel.shape:
    raise ValueError('The image has spatial dimensions %s but the channel has '
                     'dimensions %s' % (image.shape[:2], channel.shape))
  if image.shape[:2] != mask.shape:
    raise ValueError('The image has spatial dimensions %s but the mask has '
                     'dimensions %s' % (image.shape[:2], mask.shape))
  # 255 * alpha must fit in uint8; values outside [0, 1] (or NaN) would hit an
  # undefined float -> uint8 cast and produce a garbage blend mask.
  if not 0.0 <= alpha <= 1.0:
    raise ValueError('`alpha` must be in [0, 1] but is %s' % (alpha,))

  cm = plt.get_cmap(cmap)
  pil_image = Image.fromarray(image)
  # Keep only the RGB components of the colormap output; the opacity of the
  # overlay is controlled by `alpha` through the blend mask below.
  colored_channel = cm(channel)[:, :, :3]
  pil_colored_channel = Image.fromarray(
      np.uint8(colored_channel * 255)).convert('RGBA')
  # Constant opacity 255 * alpha on every nonzero mask pixel, 0 elsewhere.
  pil_mask = Image.fromarray(np.uint8(255.0 * alpha * (mask > 0))).convert('L')
  # Where the mask is 255 the colored channel is taken, where it is 0 the
  # original image is kept, and intermediate values blend the two.
  pil_image = Image.composite(pil_colored_channel, pil_image, pil_mask)
  np.copyto(image, np.array(pil_image.convert('RGB')))