def draw_part_mask_on_image_array()

in object-detection/visualization_utils.py [0:0]


def draw_part_mask_on_image_array(image, mask, alpha=0.4, num_parts=24):
  """Draws a part mask on an image, modifying the image in place.

  Every pixel whose mask value p is nonzero is painted with
  STANDARD_COLORS[p - 1] (for p <= num_parts) and alpha-blended over `image`.
  Background pixels (mask value 0) are left untouched. Nonzero part ids that
  have no assigned color are blended toward black.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3). It is
      overwritten with the blended result.
    mask: a uint8 numpy array of shape (img_height, img_width) with
      1-indexed parts (0 for background).
    alpha: transparency value between 0 and 1 (default: 0.4)
    num_parts: the maximum number of parts that may exist in the image (default
      24 for DensePose).

  Raises:
    ValueError: On incorrect data type for image or mask, or if the spatial
      dimensions of image and mask differ.
  """
  if image.dtype != np.uint8:
    raise ValueError('`image` not of type np.uint8')
  if mask.dtype != np.uint8:
    raise ValueError('`mask` not of type np.uint8')
  if image.shape[:2] != mask.shape:
    raise ValueError('The image has spatial dimensions %s but the mask has '
                     'dimensions %s' % (image.shape[:2], mask.shape))

  # One palette row per possible uint8 mask value, so the colored part image
  # is built with a single indexing pass instead of one full-image comparison
  # per part. Row 0 (background) and part ids without a color stay black.
  palette = np.zeros((256, 3), dtype=np.uint8)
  for part_id, color in enumerate(STANDARD_COLORS[:num_parts], start=1):
    if part_id > 255:
      break  # A uint8 mask cannot contain larger part ids.
    palette[part_id] = ImageColor.getrgb(color)
  part_colors = palette[mask]  # (img_height, img_width, 3), uint8.

  pil_image = Image.fromarray(image)
  pil_part_colors = Image.fromarray(part_colors).convert('RGBA')
  # Blend weight is `alpha` on every labelled pixel and 0 on background.
  pil_mask = Image.fromarray(np.uint8(255.0 * alpha * (mask > 0))).convert('L')
  pil_image = Image.composite(pil_part_colors, pil_image, pil_mask)
  np.copyto(image, np.array(pil_image.convert('RGB')))