def draw_heatmaps_on_image_tensors()

in object-detection/visualization_utils.py [0:0]


def draw_heatmaps_on_image_tensors(images,
                                   heatmaps,
                                   apply_sigmoid=False):
  """Draws heatmaps on batch of image tensors.

  Args:
    images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
      channels will be ignored. If C = 1, then we convert the images to RGB
      images. If C is not statically known, the images are passed through
      without any channel adjustment.
    heatmaps: [N, h, w, channel] float32 tensor of heatmaps. Note that the
      heatmaps will be resized to match the input image size before overlaying
      the heatmaps with input images. Theoretically the heatmap height width
      should have the same aspect ratio as the input image to avoid potential
      misalignment introduced by the image resize.
    apply_sigmoid: Whether to apply a sigmoid layer on top of the heatmaps. If
      the heatmaps come directly from the prediction logits, then we should
      apply the sigmoid layer to make sure the values are in between [0.0, 1.0].

  Returns:
    4D image tensor of type uint8, with heatmaps overlaid on top.
  """
  # Under TF2 shape semantics `images.shape[3]` is an int or None; under TF1
  # semantics it is a `Dimension`. Normalize it so that an unknown channel
  # count skips the channel fix-ups instead of raising a TypeError on
  # `None > 3`.
  num_channels = tf.compat.dimension_value(images.shape[3])
  if num_channels is not None:
    if num_channels > 3:
      # Additional channels are being ignored.
      images = images[:, :, :, 0:3]
    elif num_channels == 1:
      images = tf.image.grayscale_to_rgb(images)

  _, height, width, _ = shape_utils.combined_static_and_dynamic_shape(images)
  if apply_sigmoid:
    heatmaps = tf.math.sigmoid(heatmaps)
  # Bring the heatmaps to the image resolution so they can be overlaid
  # pixel-for-pixel.
  resized_heatmaps = tf.image.resize(heatmaps, size=[height, width])

  elems = [images, resized_heatmaps]

  def draw_heatmaps(image_and_heatmaps):
    """Draws heatmaps on a single image.

    Args:
      image_and_heatmaps: An [image, heatmaps] pair for one batch element, as
        sliced from `elems` by `tf.map_fn`.

    Returns:
      uint8 tensor produced by `draw_heatmaps_on_image_array`.
    """
    # The overlay is drawn by a plain Python helper, so it has to be wrapped
    # in `tf.py_function` to run from inside the graph.
    image_with_heatmaps = tf.py_function(
        draw_heatmaps_on_image_array,
        image_and_heatmaps,
        tf.uint8)
    return image_with_heatmaps

  # Visualization only: no gradients are needed through the drawing op.
  images = tf.map_fn(draw_heatmaps, elems, dtype=tf.uint8, back_prop=False)
  return images