in d2go/data/dataset_mappers/d2go_dataset_mapper_impl.py [0:0]
def _original_call(self, dataset_dict):
    """
    Modified from detectron2's original __call__ in DatasetMapper

    Reads the image described by ``dataset_dict``, applies the custom
    transform, the optional crop and the configured augmentations, and then
    converts the image (and, in training mode, the annotations / semantic
    segmentation) into the tensors stored in the returned dict.

    Args:
        dataset_dict (dict): metadata of one image. It is deep-copied first,
            so the caller's dict is not modified.

    Returns:
        dict: the copied dict with:
            * "image": float32 tensor produced from the augmented image via
              ``transpose(2, 0, 1)`` (a 2-D image gets a trailing axis first).
            * "instances": only in training mode when "annotations" is
              present; crowd annotations (``iscrowd != 0``) are dropped.
            * "sem_seg": only in training mode when "sem_seg_file_name" is
              present; a long tensor of the transformed segmentation.
        When ``self.is_train`` is false, "annotations" and
        "sem_seg_file_name" are removed and the dict is returned right after
        the image/proposals are handled (the "_post_process_" hook is not run
        on that path).

    Raises:
        NotImplementedError: in training mode, if "multi_sem_seg_file_names"
            is present in the dict.
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below

    image = self._read_image(dataset_dict, format=self.img_format)
    if not self.backfill_size:
        utils.check_image_size(dataset_dict, image)

    image, dataset_dict = self._custom_transform(image, dataset_dict)

    inputs = AugInput(image=image)
    crop_tfm = None
    if "annotations" in dataset_dict:
        # pass additional arguments, will only be used when the Augmentation
        # takes `annotations` as input
        inputs.annotations = dataset_dict["annotations"]
        # Crop around an instance if there are instances in the image.
        # np.random.choice raises ValueError on an empty sequence, so an
        # image with an empty annotation list must not take this path.
        if self.crop_gen and len(dataset_dict["annotations"]) > 0:
            crop_tfm = utils.gen_crop_transform_with_instance(
                self.crop_gen.get_crop_size(image.shape[:2]),
                image.shape[:2],
                np.random.choice(dataset_dict["annotations"]),
            )
            inputs.image = crop_tfm.apply_image(image)

    if self.crop_gen and crop_tfm is None:
        # No instance to crop around (no "annotations" key, or an empty
        # list): let the crop run as a regular augmentation instead.
        augmentations = [self.crop_gen, *self.tfm_gens]
    else:
        augmentations = self.tfm_gens
    transforms = AugmentationList(augmentations)(inputs)
    image = inputs.image
    if crop_tfm is not None:
        # The instance crop was applied by hand above, so prepend it to keep
        # annotations/proposals/segmentation in sync with the image.
        transforms = crop_tfm + transforms

    image_shape = image.shape[:2]  # h, w
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
    dataset_dict["image"] = torch.as_tensor(
        image.transpose(2, 0, 1).astype("float32")
    )
    # Can use uint8 if it turns out to be slow some day

    if self.load_proposals:
        utils.transform_proposals(
            dataset_dict,
            image_shape,
            transforms,
            proposal_topk=self.proposal_topk,
            min_box_size=self.proposal_min_box_size,
        )

    if not self.is_train:
        # Ground truth is not needed outside training; drop it and stop here.
        dataset_dict.pop("annotations", None)
        dataset_dict.pop("sem_seg_file_name", None)
        return dataset_dict

    if "annotations" in dataset_dict:
        # Strip the annotation fields that the model is not configured to use.
        for anno in dataset_dict["annotations"]:
            if not self.mask_on:
                anno.pop("segmentation", None)
            if not self.keypoint_on:
                anno.pop("keypoints", None)

        annos = [
            utils.transform_instance_annotations(
                obj,
                transforms,
                image_shape,
                keypoint_hflip_indices=self.keypoint_hflip_indices,
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.mask_format
        )
        # Create a tight bounding box from masks, useful when image is cropped
        if self.crop_gen and instances.has("gt_masks"):
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()

        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    if "sem_seg_file_name" in dataset_dict:
        sem_seg_gt = read_sem_seg_file_with_prefetch(
            dataset_dict.pop("sem_seg_file_name"),
            prefetched=dataset_dict.get(PREFETCHED_SEM_SEG_FILE_NAME, None),
        )
        # Drop the third axis when the file was read with more than 2 dims.
        if len(sem_seg_gt.shape) > 2:
            sem_seg_gt = sem_seg_gt.squeeze(2)
        sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
        sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
        dataset_dict["sem_seg"] = sem_seg_gt

    # extend standard D2 semantic segmentation to support multiple segmentation
    # files, each file can represent a class
    if "multi_sem_seg_file_names" in dataset_dict:
        raise NotImplementedError()

    if "_post_process_" in dataset_dict:
        # The hook is removed from the dict before being called and its
        # return value replaces the dict.
        proc_func = dataset_dict.pop("_post_process_")
        dataset_dict = proc_func(dataset_dict)

    return dataset_dict