def _create_dataset_parser_fn()

in models/official/mask_rcnn/dataloader.py


  def _create_dataset_parser_fn(self, params):
    """Create parser for parsing input data (dictionary)."""
    example_decoder = self._create_example_decoder()

    def _dataset_parser(value):
      """Parse data to a fixed dimension input image and learning targets.

      Args:
        value: A dictionary containing an image and groundtruth annotations.

      Returns:
        features: a dictionary that contains the image and auxiliary
          information. The following describes {key: value} pairs in the
          dictionary.
          image: Image tensor that is preprocessed to have normalized values
            and a fixed dimension [image_size, image_size, 3].
          image_info: image information that includes the original height and
            width, the scale of the processed image relative to the original
            image, and the scaled height and width.
          source_ids: Source image id. Defaults to -1 if the source id is
            empty in the groundtruth annotation.
        labels: a dictionary that contains auxiliary information plus (optional)
          labels. The following describes {key: value} pairs in the dictionary.
          `labels` is only for training.
          score_targets_dict: ordered dictionary with keys
            [min_level, min_level+1, ..., max_level]. The values are tensors
            with shape [height_l, width_l, num_anchors]. height_l and width_l
            are the spatial dimensions of the objectness score at the l-th
            level.
          box_targets_dict: ordered dictionary with keys
            [min_level, min_level+1, ..., max_level]. The values are tensors
            with shape [height_l, width_l, num_anchors * 4]. height_l and
            width_l are the spatial dimensions of the bounding box regression
            output at the l-th level.
          gt_boxes: Groundtruth bounding box annotations. Each box is
            represented in [y1, x1, y2, x2] format. The tensor is padded with
            -1 to the fixed dimension [self._max_num_instances, 4].
          gt_classes: Groundtruth class annotations. The tensor is padded
            with -1 to the fixed dimension [self._max_num_instances].
          cropped_gt_masks: Groundtruth masks cropped by the bounding boxes
            and resized to a fixed size determined by params['gt_mask_size'].
      """
      with tf.name_scope('parser'):
        data = example_decoder.decode(value)
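        # If the decoded 'groundtruth_is_crowd' field is empty, default it to
        # all-False (treat every box as non-crowd).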
        data['groundtruth_is_crowd'] = tf.cond(
            tf.greater(tf.size(data['groundtruth_is_crowd']), 0),
            lambda: data['groundtruth_is_crowd'],
            lambda: tf.zeros_like(data['groundtruth_classes'], dtype=tf.bool))
        image = data['image']
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        orig_image = image
        source_id = data['source_id']
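        # Use '-1' as the source id when the annotation leaves it empty.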
        source_id = tf.where(tf.equal(source_id, tf.constant('')), '-1',
                             source_id)
        source_id = tf.string_to_number(source_id)

        if (self._mode == tf.estimator.ModeKeys.PREDICT or
            self._mode == tf.estimator.ModeKeys.EVAL):
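          # Eval/predict: normalize and resize/pad the image only; groundtruth
          # labels are attached below when explicitly requested.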
          image = preprocess_ops.normalize_image(image)
          if params['resize_method'] == 'retinanet':
            image, image_info, _, _, _ = preprocess_ops.resize_crop_pad(
                image, params['image_size'], 2 ** params['max_level'])
          else:
            image, image_info, _, _, _ = preprocess_ops.resize_crop_pad_v2(
                image, params['short_side'], params['long_side'],
                2 ** params['max_level'])
          if params['precision'] == 'bfloat16':
            image = tf.cast(image, dtype=tf.bfloat16)

          features = {
              'images': image,
              'image_info': image_info,
              'source_ids': source_id,
          }
          if params['visualize_images_summary']:
            resized_image = tf.image.resize_images(orig_image,
                                                   params['image_size'])
            features['orig_images'] = resized_image
          if (params['include_groundtruth_in_features'] or
              self._mode == tf.estimator.ModeKeys.EVAL):
            labels = _prepare_labels_for_eval(
                data,
                target_num_instances=self._max_num_instances,
                target_polygon_list_len=self._max_num_polygon_list_len,
                use_instance_mask=params['include_mask'])
            return {'features': features, 'labels': labels}
          else:
            return {'features': features}

        elif self._mode == tf.estimator.ModeKeys.TRAIN:
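          # Training: apply augmentations and build RPN score/box targets.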
          instance_masks = None
          if self._use_instance_mask:
            instance_masks = data['groundtruth_instance_masks']
          boxes = data['groundtruth_boxes']
          classes = data['groundtruth_classes']
          classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1])
          if not params['use_category']:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

          # Optionally drop crowd annotations before computing training
          # targets.
          if (params['skip_crowd_during_training'] and
              self._mode == tf.estimator.ModeKeys.TRAIN):
            indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
            classes = tf.gather_nd(classes, indices)
            boxes = tf.gather_nd(boxes, indices)
            if self._use_instance_mask:
              instance_masks = tf.gather_nd(instance_masks, indices)

          image = preprocess_ops.normalize_image(image)
          if params['input_rand_hflip']:
            flipped_results = (
                preprocess_ops.random_horizontal_flip(
                    image, boxes=boxes, masks=instance_masks))
            if self._use_instance_mask:
              image, boxes, instance_masks = flipped_results
            else:
              image, boxes = flipped_results
          # Scaling, jittering and padding.
          if params['resize_method'] == 'retinanet':
            image, image_info, boxes, classes, cropped_gt_masks = (
                preprocess_ops.resize_crop_pad(
                    image,
                    params['image_size'],
                    2 ** params['max_level'],
                    aug_scale_min=params['aug_scale_min'],
                    aug_scale_max=params['aug_scale_max'],
                    boxes=boxes,
                    classes=classes,
                    masks=instance_masks,
                    crop_mask_size=params['gt_mask_size']))
          else:
            image, image_info, boxes, classes, cropped_gt_masks = (
                preprocess_ops.resize_crop_pad_v2(
                    image,
                    params['short_side'],
                    params['long_side'],
                    2 ** params['max_level'],
                    aug_scale_min=params['aug_scale_min'],
                    aug_scale_max=params['aug_scale_max'],
                    boxes=boxes,
                    classes=classes,
                    masks=instance_masks,
                    crop_mask_size=params['gt_mask_size']))
          if cropped_gt_masks is not None:
            # Pad each cropped mask with a 2-pixel border of zeros on every
            # spatial side.
            cropped_gt_masks = tf.pad(
                cropped_gt_masks,
                paddings=tf.constant([[0, 0], [2, 2], [2, 2]]),
                mode='CONSTANT',
                constant_values=0.)

          # Anchors are generated for the padded (statically known) image size.
          padded_height, padded_width, _ = image.get_shape().as_list()
          padded_image_size = (padded_height, padded_width)
          input_anchors = anchors.Anchors(
              params['min_level'],
              params['max_level'],
              params['num_scales'],
              params['aspect_ratios'],
              params['anchor_scale'],
              padded_image_size)
          anchor_labeler = anchors.AnchorLabeler(
              input_anchors,
              params['num_classes'],
              params['rpn_positive_overlap'],
              params['rpn_negative_overlap'],
              params['rpn_batch_size_per_im'],
              params['rpn_fg_fraction'])

          # Assign anchors: each anchor gets an objectness score target and a
          # box regression target.
          score_targets, box_targets = anchor_labeler.label_anchors(
              boxes, classes)

          # Pad groundtruth data.
          boxes = preprocess_ops.pad_to_fixed_size(
              boxes, -1, [self._max_num_instances, 4])
          classes = preprocess_ops.pad_to_fixed_size(
              classes, -1, [self._max_num_instances, 1])

          # Flatten masks, pad to a fixed number of instances, then restore
          # the spatial shape (gt_mask_size + 4 accounts for the 2-pixel pad).
          if self._use_instance_mask:
            cropped_gt_masks = tf.reshape(
                cropped_gt_masks, tf.stack([tf.shape(cropped_gt_masks)[0], -1]))
            cropped_gt_masks = preprocess_ops.pad_to_fixed_size(
                cropped_gt_masks, -1,
                [self._max_num_instances, (params['gt_mask_size'] + 4) ** 2])
            cropped_gt_masks = tf.reshape(
                cropped_gt_masks,
                [self._max_num_instances, params['gt_mask_size'] + 4,
                 params['gt_mask_size'] + 4])

          if params['precision'] == 'bfloat16':
            image = tf.cast(image, dtype=tf.bfloat16)

          features = {
              'images': image,
              'image_info': image_info,
              'source_ids': source_id,
          }
          labels = {}
          for level in range(params['min_level'], params['max_level'] + 1):
            labels['score_targets_%d' % level] = score_targets[level]
            labels['box_targets_%d' % level] = box_targets[level]
          labels['gt_boxes'] = boxes
          labels['gt_classes'] = classes
          if self._use_instance_mask:
            labels['cropped_gt_masks'] = cropped_gt_masks
          return features, labels

    return _dataset_parser
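
A minimal sketch of how the returned parser is typically consumed in a
tf.data pipeline. This is not part of dataloader.py; the names `input_fn`,
`file_pattern`, and `batch_size` below are illustrative assumptions.

import tensorflow as tf

def input_fn(file_pattern, batch_size, parser_fn):
  """Builds a tf.data pipeline around the parser returned above."""
  # List and shuffle the TFRecord shards.
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
  # Read records from several shards in round-robin fashion.
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=16)
  # Decode and preprocess each serialized example with _dataset_parser.
  dataset = dataset.map(
      parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Fixed-size parser outputs allow static batching with drop_remainder.
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)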