in object-detection/visualization_utils.py [0:0]
def draw_heatmaps_on_image_tensors(images,
                                   heatmaps,
                                   apply_sigmoid=False):
    """Overlays heatmaps onto every image of a batch.

    Args:
      images: 4D uint8 image tensor shaped [N, H, W, C]. Only the first three
        channels are used when C > 3; a single-channel (C = 1) batch is
        converted to RGB first.
      heatmaps: float32 tensor shaped [N, h, w, channel]. The heatmaps are
        resized to the image height and width before being overlaid, so
        their aspect ratio should ideally match that of the images to avoid
        misalignment caused by the resize.
      apply_sigmoid: If True, pass the heatmaps through a sigmoid first. Use
        this when the heatmaps are raw prediction logits so that the values
        fall within [0.0, 1.0].

    Returns:
      A 4D uint8 image tensor with the heatmaps drawn on top of the images.
    """
    num_channels = images.shape[3]
    if num_channels > 3:
        # Channels beyond the first three are dropped.
        images = images[:, :, :, :3]
    elif num_channels == 1:
        images = tf.image.grayscale_to_rgb(images)

    # Resize target is the (possibly channel-adjusted) image height/width.
    _, img_height, img_width, _ = (
        shape_utils.combined_static_and_dynamic_shape(images))

    scores = tf.math.sigmoid(heatmaps) if apply_sigmoid else heatmaps
    scores = tf.image.resize(scores, size=[img_height, img_width])

    def _overlay_one(image_heatmap_pair):
        """Draws one example's heatmaps on its image through a Python op."""
        return tf.py_function(
            draw_heatmaps_on_image_array, image_heatmap_pair, tf.uint8)

    # map_fn hands `_overlay_one` one [image, heatmaps] pair per batch entry.
    return tf.map_fn(
        _overlay_one,
        [images, scores],
        dtype=tf.uint8,
        back_prop=False)