in 3_predict/visualization_utils.py [0:0]
def create_visualization_fn(category_index,
                            include_masks=False,
                            include_keypoints=False,
                            include_keypoint_scores=False,
                            include_track_ids=False,
                            **kwargs):
  """Constructs a visualization function that can be wrapped in a py_func.

  py_funcs only accept positional arguments. This function returns a suitable
  function with the correct positional argument mapping. The positional
  arguments in order are:
    0: image
    1: boxes
    2: classes
    3: scores
    [4]: masks (optional)
    [4-5]: keypoints (optional)
    [4-6]: keypoint_scores (optional)
    [4-7]: track_ids (optional)

  Each optional argument is present only when its `include_*` flag is True.
  Enabled optional arguments always keep the relative order listed above.
  The returned function requires exactly 4 + (number of enabled flags)
  positional arguments.

  -- Example 1 --
  vis_only_masks_fn = create_visualization_fn(category_index,
    include_masks=True, include_keypoints=False, include_track_ids=False,
    **kwargs)
  image = tf.py_func(vis_only_masks_fn,
                     inp=[image, boxes, classes, scores, masks],
                     Tout=tf.uint8)

  -- Example 2 --
  vis_masks_and_track_ids_fn = create_visualization_fn(category_index,
    include_masks=True, include_keypoints=False, include_track_ids=True,
    **kwargs)
  image = tf.py_func(vis_masks_and_track_ids_fn,
                     inp=[image, boxes, classes, scores, masks, track_ids],
                     Tout=tf.uint8)

  Args:
    category_index: a dict that maps integer ids to category dicts. It is
      forwarded unchanged to visualize_boxes_and_labels_on_image_array.
      e.g. {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
      NOTE(review): the 'id'/'name' layout follows the TF Object Detection
      API convention. Confirm it against
      visualize_boxes_and_labels_on_image_array.
    include_masks: Whether masks should be expected as a positional argument in
      the returned function.
    include_keypoints: Whether keypoints should be expected as a positional
      argument in the returned function.
    include_keypoint_scores: Whether keypoint scores should be expected as a
      positional argument in the returned function.
    include_track_ids: Whether track ids should be expected as a positional
      argument in the returned function.
    **kwargs: Additional kwargs that will be passed to
      visualize_boxes_and_labels_on_image_array. The returned function always
      passes `category_index`, `instance_masks`, `keypoints`,
      `keypoint_scores` and `track_ids` explicitly. Repeating any of those
      keys here makes each call fail with a TypeError.

  Returns:
    Returns a function that only takes tensors as positional arguments.
  """
  # Build the optional argument names in their fixed positional order, keeping
  # only those this function was configured to expect. This is resolved once
  # here, instead of re-walking the flags on every call.
  optional_arg_names = tuple(
      name for name, enabled in (('instance_masks', include_masks),
                                 ('keypoints', include_keypoints),
                                 ('keypoint_scores', include_keypoint_scores),
                                 ('track_ids', include_track_ids))
      if enabled)
  num_expected_args = 4 + len(optional_arg_names)

  def visualization_py_func_fn(*args):
    """Visualization function that can be wrapped in a tf.py_func.

    Args:
      *args: First 4 positional arguments must be:
        image - uint8 numpy array with shape (img_height, img_width, 3).
        boxes - a numpy array of shape [N, 4].
        classes - a numpy array of shape [N].
        scores - a numpy array of shape [N] or None.
        -- Optional positional arguments --
        instance_masks - a numpy array of shape [N, image_height, image_width].
        keypoints - a numpy array of shape [N, num_keypoints, 2].
        keypoint_scores - a numpy array of shape [N, num_keypoints].
        track_ids - a numpy array of shape [N] with unique track ids.

    Returns:
      uint8 numpy array with shape (img_height, img_width, 3) with overlaid
      boxes.

    Raises:
      ValueError: if the number of positional arguments does not match the
        include_* flags given to create_visualization_fn.
    """
    # A mismatch between the include_* flags and the py_func `inp` list would
    # otherwise silently drop surplus tensors or fail with a bare IndexError.
    # Either way, tensors end up bound to the wrong role, so fail loudly here.
    if len(args) != num_expected_args:
      raise ValueError(
          'Expected {} positional arguments (image, boxes, classes, scores{}) '
          'but got {}.'.format(
              num_expected_args,
              ''.join(', ' + name for name in optional_arg_names),
              len(args)))
    image, boxes, classes, scores = args[:4]
    # Every optional input defaults to None. Only the enabled ones are filled,
    # in order, from the trailing positional arguments.
    optional_args = dict.fromkeys(
        ('instance_masks', 'keypoints', 'keypoint_scores', 'track_ids'))
    optional_args.update(zip(optional_arg_names, args[4:]))
    return visualize_boxes_and_labels_on_image_array(
        image,
        boxes,
        classes,
        scores,
        category_index=category_index,
        instance_masks=optional_args['instance_masks'],
        keypoints=optional_args['keypoints'],
        keypoint_scores=optional_args['keypoint_scores'],
        track_ids=optional_args['track_ids'],
        **kwargs)

  return visualization_py_func_fn