assets/aml-benchmark/scripts/data_loaders/mscoco.py (881 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data Loading Script for MSCOCO.

Adapted from https://github.com/shunk031/huggingface-datasets_MSCOCO/blob/main/MSCOCO.py
to only download validation images.
"""

import abc
import json
import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import (
    Any,
    Dict,
    Final,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
    get_args,
)

import datasets as ds
import numpy as np
from datasets.data_files import DataFilesDict
from PIL import Image
from PIL.Image import Image as PilImage
from pycocotools import mask as cocomask
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)

JsonDict = Dict[str, Any]
ImageId = int
AnnotationId = int
LicenseId = int
CategoryId = int
Bbox = Tuple[float, float, float, float]

MscocoSplits = Literal["train", "val", "test"]

# A keypoint's visibility flag ``v`` is used as an index into this list.
KEYPOINT_STATE: Final[List[str]] = ["unknown", "invisible", "visible"]

_CITATION = """
"""

_DESCRIPTION = """
"""

_HOMEPAGE = """
"""

_LICENSE = "https://creativecommons.org/licenses/by/4.0/legalcode"

# Archive URLs per COCO release year. Only the validation image archives are
# enabled; the other image archives are commented out to avoid large downloads.
_URLS = {
    "2014": {
        "images": {
            # "train": "http://images.cocodataset.org/zips/train2014.zip",
            "validation": "http://images.cocodataset.org/zips/val2014.zip",
            # "test": "http://images.cocodataset.org/zips/test2014.zip",
        },
        "annotations": {
            "train_validation": "http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
            "test_image_info": "http://images.cocodataset.org/annotations/image_info_test2014.zip",
        },
    },
    "2015": {
        "images": {
            "test": "http://images.cocodataset.org/zips/test2015.zip",
        },
        "annotations": {
            "test_image_info": "http://images.cocodataset.org/annotations/image_info_test2015.zip",
        },
    },
    "2017": {
        "images": {
            # "train": "http://images.cocodataset.org/zips/train2017.zip",
            "validation": "http://images.cocodataset.org/zips/val2017.zip",
            # "test": "http://images.cocodataset.org/zips/test2017.zip",
            # "unlabeled": "http://images.cocodataset.org/zips/unlabeled2017.zip",
        },
        "annotations": {
            "train_validation": "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
            "stuff_train_validation": "http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip",
            "panoptic_train_validation": (
                "http://images.cocodataset.org/annotations/"
                "panoptic_annotations_trainval2017.zip"
            ),
            "test_image_info": "http://images.cocodataset.org/annotations/image_info_test2017.zip",
            "unlabeled": "http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip",
        },
    },
}

# Class names used for the ``category.name`` ClassLabel feature.
CATEGORIES: Final[List[str]] = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]

# Super-category names used for the ``category.supercategory`` ClassLabel feature.
SUPER_CATEGORIES: Final[List[str]] = [
    "person",
    "vehicle",
    "outdoor",
    "animal",
    "accessory",
    "sports",
    "kitchen",
    "food",
    "furniture",
    "electronic",
    "appliance",
    "indoor",
]


@dataclass
class AnnotationInfo(object):
    """The ``info`` section of a COCO annotation file."""

    description: str
    url: str
    version: str
    year: str
    contributor: str
    date_created: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "AnnotationInfo":
        """Build an instance from the raw ``info`` object (keys map 1:1 to fields)."""
        return cls(**json_dict)


@dataclass
class LicenseData(object):
    """One entry of the ``licenses`` section of a COCO annotation file."""

    url: str
    license_id: LicenseId
    name: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "LicenseData":
        """Build an instance from a raw license object (its ``id`` becomes ``license_id``)."""
        return cls(
            license_id=json_dict["id"],
            url=json_dict["url"],
            name=json_dict["name"],
        )


@dataclass
class ImageData(object):
    """One entry of the ``images`` section of a COCO annotation file."""

    image_id: ImageId
    license_id: LicenseId
    file_name: str
    coco_url: str
    height: int
    width: int
    date_captured: str
    flickr_url: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "ImageData":
        """Build an instance from a raw image object (``id`` -> ``image_id``, ``license`` -> ``license_id``)."""
        return cls(
            image_id=json_dict["id"],
            license_id=json_dict["license"],
            file_name=json_dict["file_name"],
            coco_url=json_dict["coco_url"],
            height=json_dict["height"],
            width=json_dict["width"],
            date_captured=json_dict["date_captured"],
            flickr_url=json_dict["flickr_url"],
        )

    @property
    def shape(self) -> Tuple[int, int]:
        """Return the image size as ``(height, width)``."""
        return (self.height, self.width)


@dataclass
class CategoryData(object):
    """One entry of the ``categories`` section of a COCO annotation file."""

    category_id: int
    name: str
    supercategory: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "CategoryData":
        """Build an instance from a raw category object (``id`` -> ``category_id``)."""
        return cls(
            category_id=json_dict["id"],
            name=json_dict["name"],
            supercategory=json_dict["supercategory"],
        )


@dataclass
class AnnotationData(object):
    """Fields shared by every annotation type."""

    annotation_id: AnnotationId
    image_id: ImageId


@dataclass
class CaptionsAnnotationData(AnnotationData):
    """A caption annotation for one image."""

    caption: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "CaptionsAnnotationData":
        """Build an instance from a raw caption annotation object."""
        return cls(
            annotation_id=json_dict["id"],
            image_id=json_dict["image_id"],
            caption=json_dict["caption"],
        )


class UncompressedRLE(TypedDict):
    """Run-length encoding whose ``counts`` is a plain list of ints."""

    counts: List[int]
    size: Tuple[int, int]


class CompressedRLE(TypedDict):
    """Run-length encoding whose ``counts`` is a compressed byte string."""

    counts: bytes
    size: Tuple[int, int]


@dataclass
class InstancesAnnotationData(AnnotationData):
    """An object-instance annotation: segmentation, area, bbox and category."""

    # Decoded 0/255 mask when the config's decode_rle is set, else a compressed RLE.
    segmentation: Union[np.ndarray, CompressedRLE]
    area: float
    iscrowd: bool
    bbox: Tuple[float, float, float, float]
    category_id: int

    @classmethod
    def compress_rle(
        cls,
        segmentation: Union[List[List[float]], UncompressedRLE],
        iscrowd: bool,
        height: int,
        width: int,
    ) -> CompressedRLE:
        """Convert a raw segmentation into a single compressed RLE.

        For crowd annotations the segmentation is converted directly.
        Otherwise it is expected to be a list of polygons. Each polygon
        becomes one RLE, and the RLEs are merged into a single mask.
        """
        if iscrowd:
            rle = cocomask.frPyObjects(segmentation, h=height, w=width)
        else:
            rles = cocomask.frPyObjects(segmentation, h=height, w=width)
            rle = cocomask.merge(rles)

        return rle  # type: ignore

    @classmethod
    def rle_segmentation_to_binary_mask(
        cls,
        segmentation: Union[List[List[float]], UncompressedRLE],
        iscrowd: bool,
        height: int,
        width: int,
    ) -> np.ndarray:
        """Decode a raw segmentation into a 0/1 mask via ``cocomask.decode``."""
        rle = cls.compress_rle(
            segmentation=segmentation, iscrowd=iscrowd, height=height, width=width
        )
        return cocomask.decode(rle)  # type: ignore

    @classmethod
    def rle_segmentation_to_mask(
        cls,
        segmentation: Union[List[List[float]], UncompressedRLE],
        iscrowd: bool,
        height: int,
        width: int,
    ) -> np.ndarray:
        """Decode a raw segmentation into a 0/255 mask suitable for an image feature."""
        binary_mask = cls.rle_segmentation_to_binary_mask(
            segmentation=segmentation, iscrowd=iscrowd, height=height, width=width
        )
        return binary_mask * 255

    @classmethod
    def _convert_segmentation(
        cls,
        segmentation: Union[List[List[float]], UncompressedRLE],
        iscrowd: bool,
        image_data: ImageData,
        decode_rle: bool,
    ) -> Union[np.ndarray, CompressedRLE]:
        """Return a 0/255 mask if ``decode_rle`` is set, otherwise a compressed RLE."""
        convert = cls.rle_segmentation_to_mask if decode_rle else cls.compress_rle
        return convert(
            segmentation=segmentation,
            iscrowd=iscrowd,
            height=image_data.height,
            width=image_data.width,
        )

    @classmethod
    def from_dict(
        cls,
        json_dict: JsonDict,
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
    ) -> "InstancesAnnotationData":
        """Build an instance from a raw instance annotation object.

        Args:
            json_dict: the raw annotation object.
            images: all images of the split. The annotated image's height and
                width are needed to rasterize the segmentation.
            decode_rle: decode the segmentation into a mask instead of keeping
                a compressed RLE.

        Raises:
            KeyError: if the annotation references an image id missing from ``images``.
        """
        image_id = json_dict["image_id"]
        iscrowd = bool(json_dict["iscrowd"])
        segmentation_mask = cls._convert_segmentation(
            segmentation=json_dict["segmentation"],
            iscrowd=iscrowd,
            image_data=images[image_id],
            decode_rle=decode_rle,
        )
        return cls(
            # for AnnotationData
            annotation_id=json_dict["id"],
            image_id=image_id,
            # for InstancesAnnotationData
            segmentation=segmentation_mask,  # type: ignore
            area=json_dict["area"],
            iscrowd=iscrowd,
            bbox=json_dict["bbox"],
            category_id=json_dict["category_id"],
        )


@dataclass
class PersonKeypoint(object):
    """A single body keypoint with its visibility flag and state name."""

    x: int
    y: int
    v: int
    state: str


@dataclass
class PersonKeypointsAnnotationData(InstancesAnnotationData):
    """A person instance annotation that also carries body keypoints."""

    num_keypoints: int
    keypoints: List[PersonKeypoint]

    @classmethod
    def v_keypoint_to_state(cls, keypoint_v: int) -> str:
        """Map a visibility flag ``v`` to its name in ``KEYPOINT_STATE``."""
        return KEYPOINT_STATE[keypoint_v]

    @classmethod
    def get_person_keypoints(
        cls, flatten_keypoints: List[int], num_keypoints: int
    ) -> List[PersonKeypoint]:
        """Split a flat ``[x1, y1, v1, x2, y2, v2, ...]`` list into keypoints.

        Asserts that the number of keypoints whose state is not ``"unknown"``
        equals ``num_keypoints``.
        """
        keypoints_x = flatten_keypoints[0::3]
        keypoints_y = flatten_keypoints[1::3]
        keypoints_v = flatten_keypoints[2::3]
        assert len(keypoints_x) == len(keypoints_y) == len(keypoints_v)

        keypoints = [
            PersonKeypoint(x=x, y=y, v=v, state=cls.v_keypoint_to_state(v))
            for x, y, v in zip(keypoints_x, keypoints_y, keypoints_v)
        ]
        assert len([kp for kp in keypoints if kp.state != "unknown"]) == num_keypoints
        return keypoints

    @classmethod
    def from_dict(
        cls,
        json_dict: JsonDict,
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
    ) -> "PersonKeypointsAnnotationData":
        """Build an instance from a raw person-keypoints annotation object.

        Args:
            json_dict: the raw annotation object.
            images: all images of the split. The annotated image's height and
                width are needed to rasterize the segmentation.
            decode_rle: decode the segmentation into a mask instead of keeping
                a compressed RLE.
        """
        image_id = json_dict["image_id"]
        iscrowd = bool(json_dict["iscrowd"])
        segmentation_mask = cls._convert_segmentation(
            segmentation=json_dict["segmentation"],
            iscrowd=iscrowd,
            image_data=images[image_id],
            decode_rle=decode_rle,
        )
        num_keypoints = json_dict["num_keypoints"]
        keypoints = cls.get_person_keypoints(json_dict["keypoints"], num_keypoints)

        return cls(
            # for AnnotationData
            annotation_id=json_dict["id"],
            image_id=image_id,
            # for InstancesAnnotationData
            segmentation=segmentation_mask,  # type: ignore
            area=json_dict["area"],
            iscrowd=iscrowd,
            bbox=json_dict["bbox"],
            category_id=json_dict["category_id"],
            # for PersonKeypointsAnnotationData
            num_keypoints=num_keypoints,
            keypoints=keypoints,
        )


class LicenseDict(TypedDict):
    """License information attached to an example."""

    license_id: LicenseId
    name: str
    url: str


class BaseExample(TypedDict):
    """Image-level fields shared by every example type."""

    image_id: ImageId
    image: PilImage
    file_name: str
    coco_url: str
    height: int
    width: int
    date_captured: str
    flickr_url: str
    license_id: LicenseId
    license: LicenseDict


class CaptionAnnotationDict(TypedDict):
    """A caption annotation as emitted in an example."""

    annotation_id: AnnotationId
    caption: str


class CaptionExample(BaseExample):
    """An example of the ``captions`` task."""

    annotations: List[CaptionAnnotationDict]


class CategoryDict(TypedDict):
    """Category information attached to an instance annotation."""

    category_id: CategoryId
    name: str
    supercategory: str


class InstanceAnnotationDict(TypedDict):
    """An instance annotation as emitted in an example."""

    annotation_id: AnnotationId
    area: float
    bbox: Bbox
    image_id: ImageId
    category_id: CategoryId
    category: CategoryDict
    iscrowd: bool
    segmentation: np.ndarray


class InstanceExample(BaseExample):
    """An example of the ``instances`` task."""

    annotations: List[InstanceAnnotationDict]


class KeypointDict(TypedDict):
    """A keypoint as emitted in an example."""

    x: int
    y: int
    v: int
    state: str


class PersonKeypointAnnotationDict(InstanceAnnotationDict):
    """A person-keypoints annotation as emitted in an example."""

    num_keypoints: int
    keypoints: List[KeypointDict]


class PersonKeypointExample(BaseExample):
    """An example of the ``person_keypoints`` task."""

    annotations: List[PersonKeypointAnnotationDict]


class MsCocoProcessor(object, metaclass=abc.ABCMeta):
    """Base class for task-specific loading, feature definition and example generation."""

    def load_image(self, image_path: str) -> PilImage:
        """Open the image at ``image_path`` with PIL."""
        return Image.open(image_path)

    def load_annotation_json(self, ann_file_path: str) -> JsonDict:
        """Read and parse a COCO annotation JSON file (decoded as UTF-8)."""
        logger.info("Load annotation json from %s", ann_file_path)
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # can fail on non-ASCII captions or metadata.
        with open(ann_file_path, "r", encoding="utf-8") as rf:
            ann_json = json.load(rf)
        return ann_json

    def load_licenses_data(
        self, license_dicts: List[JsonDict]
    ) -> Dict[LicenseId, LicenseData]:
        """Parse the ``licenses`` section into a mapping keyed by license id."""
        licenses = {}
        for license_dict in license_dicts:
            license_data = LicenseData.from_dict(license_dict)
            licenses[license_data.license_id] = license_data
        return licenses

    def load_images_data(
        self,
        image_dicts: List[JsonDict],
        tqdm_desc: str = "Load images",
    ) -> Dict[ImageId, ImageData]:
        """Parse the ``images`` section into a mapping keyed by image id."""
        images = {}
        for image_dict in tqdm(image_dicts, desc=tqdm_desc):
            image_data = ImageData.from_dict(image_dict)
            images[image_data.image_id] = image_data
        return images

    def load_categories_data(
        self,
        category_dicts: List[JsonDict],
        tqdm_desc: str = "Load categories",
    ) -> Dict[CategoryId, CategoryData]:
        """Parse the ``categories`` section into a mapping keyed by category id."""
        categories = {}
        for category_dict in tqdm(category_dicts, desc=tqdm_desc):
            category_data = CategoryData.from_dict(category_dict)
            categories[category_data.category_id] = category_data
        return categories

    def get_features_base_dict(self):
        """Return the ``datasets`` feature spec for the image-level fields."""
        return {
            "image_id": ds.Value("int64"),
            "image": ds.Image(),
            "file_name": ds.Value("string"),
            "coco_url": ds.Value("string"),
            "height": ds.Value("int32"),
            "width": ds.Value("int32"),
            "date_captured": ds.Value("string"),
            "flickr_url": ds.Value("string"),
            "license_id": ds.Value("int32"),
            "license": {
                "url": ds.Value("string"),
                "license_id": ds.Value("int8"),
                "name": ds.Value("string"),
            },
        }

    def _build_base_example(
        self,
        image_dir: str,
        image_data: ImageData,
        licenses: Dict[LicenseId, LicenseData],
    ) -> Dict[str, Any]:
        """Return the image-level fields of an example with an empty ``annotations`` list."""
        image = self.load_image(
            image_path=os.path.join(image_dir, image_data.file_name),
        )
        example = asdict(image_data)
        example["image"] = image
        example["license"] = asdict(licenses[image_data.license_id])
        example["annotations"] = []
        return example

    @abc.abstractmethod
    def get_features(self, *args, **kwargs) -> ds.Features:
        """Return the full ``datasets`` feature spec for this task."""
        raise NotImplementedError

    @abc.abstractmethod
    def load_data(self, ann_dicts: List[JsonDict], tqdm_desc: str = "", **kwargs):
        """Parse the ``annotations`` section and group the annotations by image id."""
        assert tqdm_desc != "", "tqdm_desc must be provided."
        raise NotImplementedError

    @abc.abstractmethod
    def generate_examples(
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[AnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        **kwargs,
    ):
        """Yield ``(key, example)`` pairs for the split."""
        raise NotImplementedError


class CaptionsProcessor(MsCocoProcessor):
    """Processor for the ``captions`` task."""

    def get_features(self, *args, **kwargs) -> ds.Features:
        """Return the feature spec: image fields plus a sequence of captions."""
        features_dict = self.get_features_base_dict()
        annotations = ds.Sequence(
            {
                "annotation_id": ds.Value("int64"),
                "image_id": ds.Value("int64"),
                "caption": ds.Value("string"),
            }
        )
        features_dict.update({"annotations": annotations})
        return ds.Features(features_dict)

    def load_data(
        self,
        ann_dicts: List[JsonDict],
        tqdm_desc: str = "Load captions data",
        **kwargs,
    ) -> Dict[ImageId, List[CaptionsAnnotationData]]:
        """Parse caption annotations and group them by image id.

        Extra keyword arguments (e.g. ``images``, ``decode_rle``) are accepted
        for interface compatibility and ignored.
        """
        annotations = defaultdict(list)
        for ann_dict in tqdm(ann_dicts, desc=tqdm_desc):
            ann_data = CaptionsAnnotationData.from_dict(ann_dict)
            annotations[ann_data.image_id].append(ann_data)
        return annotations

    def generate_examples(
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[CaptionsAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        **kwargs,
    ) -> Iterator[Tuple[int, CaptionExample]]:
        """Yield one example per image, keyed by its position in ``images``.

        Asserts that every image has at least one caption.
        """
        for idx, image_id in enumerate(images.keys()):
            image_data = images[image_id]
            image_anns = annotations[image_id]
            assert len(image_anns) > 0

            example = self._build_base_example(image_dir, image_data, licenses)
            example["annotations"].extend(asdict(ann) for ann in image_anns)
            yield idx, example  # type: ignore


class InstancesProcessor(MsCocoProcessor):
    """Processor for the ``instances`` task."""

    def get_features_instance_dict(self, decode_rle: bool):
        """Return the feature spec of one instance annotation.

        The segmentation is an image feature when ``decode_rle`` is set,
        otherwise an RLE dict.
        """
        segmentation_feature = (
            ds.Image()
            if decode_rle
            else {
                "counts": ds.Sequence(ds.Value("int64")),
                "size": ds.Sequence(ds.Value("int32")),
            }
        )
        return {
            "annotation_id": ds.Value("int64"),
            "image_id": ds.Value("int64"),
            "segmentation": segmentation_feature,
            "area": ds.Value("float32"),
            "iscrowd": ds.Value("bool"),
            "bbox": ds.Sequence(ds.Value("float32"), length=4),
            "category_id": ds.Value("int32"),
            "category": {
                "category_id": ds.Value("int32"),
                "name": ds.ClassLabel(
                    num_classes=len(CATEGORIES),
                    names=CATEGORIES,
                ),
                "supercategory": ds.ClassLabel(
                    num_classes=len(SUPER_CATEGORIES),
                    names=SUPER_CATEGORIES,
                ),
            },
        }

    def get_features(self, decode_rle: bool) -> ds.Features:
        """Return the feature spec: image fields plus a sequence of instance annotations."""
        features_dict = self.get_features_base_dict()
        annotations = ds.Sequence(
            self.get_features_instance_dict(decode_rle=decode_rle)
        )
        features_dict.update({"annotations": annotations})
        return ds.Features(features_dict)

    def _load_annotations(
        self,
        ann_dicts: List[JsonDict],
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
        tqdm_desc: str,
        annotation_cls: Type[InstancesAnnotationData],
    ) -> Dict[ImageId, List[InstancesAnnotationData]]:
        """Parse ``ann_dicts`` with ``annotation_cls`` and group them by image id."""
        annotations = defaultdict(list)
        # sorted() is stable, so each image's annotations keep their file order.
        ann_dicts = sorted(ann_dicts, key=lambda d: d["image_id"])
        for ann_dict in tqdm(ann_dicts, desc=tqdm_desc):
            ann_data = annotation_cls.from_dict(
                ann_dict, images=images, decode_rle=decode_rle
            )
            annotations[ann_data.image_id].append(ann_data)
        return annotations

    def load_data(  # type: ignore[override]
        self,
        ann_dicts: List[JsonDict],
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
        tqdm_desc: str = "Load instances data",
    ) -> Dict[ImageId, List[InstancesAnnotationData]]:
        """Parse instance annotations and group them by image id."""
        return self._load_annotations(
            ann_dicts=ann_dicts,
            images=images,
            decode_rle=decode_rle,
            tqdm_desc=tqdm_desc,
            annotation_cls=InstancesAnnotationData,
        )

    def _iter_examples_with_categories(
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[InstancesAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        categories: Dict[CategoryId, CategoryData],
        warn_if_unannotated: bool,
    ) -> Iterator[Tuple[int, Dict[str, Any]]]:
        """Yield examples whose annotations carry their category.

        Images without annotations are skipped, with a warning if
        ``warn_if_unannotated`` is set. Keys are positions in ``images``, so
        skipped images leave gaps in the key sequence.
        """
        for idx, image_id in enumerate(images.keys()):
            image_data = images[image_id]
            image_anns = annotations[image_id]

            if len(image_anns) < 1:
                if warn_if_unannotated:
                    logger.warning("No annotation found for image id: %s.", image_id)
                continue

            example = self._build_base_example(image_dir, image_data, licenses)
            for ann in image_anns:
                ann_dict = asdict(ann)
                ann_dict["category"] = asdict(categories[ann.category_id])
                example["annotations"].append(ann_dict)

            yield idx, example

    def generate_examples(  # type: ignore[override]
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[InstancesAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        categories: Dict[CategoryId, CategoryData],
    ) -> Iterator[Tuple[int, InstanceExample]]:
        """Yield one example per annotated image, warning about unannotated ones."""
        yield from self._iter_examples_with_categories(  # type: ignore
            image_dir=image_dir,
            images=images,
            annotations=annotations,
            licenses=licenses,
            categories=categories,
            warn_if_unannotated=True,
        )


class PersonKeypointsProcessor(InstancesProcessor):
    """Processor for the ``person_keypoints`` task."""

    def get_features(self, decode_rle: bool) -> ds.Features:
        """Return the instance feature spec extended with keypoints."""
        features_dict = self.get_features_base_dict()
        features_instance_dict = self.get_features_instance_dict(decode_rle=decode_rle)
        features_instance_dict.update(
            {
                "keypoints": ds.Sequence(
                    {
                        "state": ds.Value("string"),
                        "x": ds.Value("int32"),
                        "y": ds.Value("int32"),
                        "v": ds.Value("int32"),
                    }
                ),
                "num_keypoints": ds.Value("int32"),
            }
        )
        annotations = ds.Sequence(features_instance_dict)
        features_dict.update({"annotations": annotations})
        return ds.Features(features_dict)

    def load_data(  # type: ignore[override]
        self,
        ann_dicts: List[JsonDict],
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
        tqdm_desc: str = "Load person keypoints data",
    ) -> Dict[ImageId, List[PersonKeypointsAnnotationData]]:
        """Parse person-keypoint annotations and group them by image id."""
        return self._load_annotations(  # type: ignore
            ann_dicts=ann_dicts,
            images=images,
            decode_rle=decode_rle,
            tqdm_desc=tqdm_desc,
            annotation_cls=PersonKeypointsAnnotationData,
        )

    def generate_examples(  # type: ignore[override]
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[PersonKeypointsAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        categories: Dict[CategoryId, CategoryData],
    ) -> Iterator[Tuple[int, PersonKeypointExample]]:
        """Yield one example per image that has keypoint annotations.

        Images without persons have no keypoint annotations. They are
        skipped silently.
        """
        yield from self._iter_examples_with_categories(  # type: ignore
            image_dir=image_dir,
            images=images,
            annotations=annotations,  # type: ignore
            licenses=licenses,
            categories=categories,
            warn_if_unannotated=False,
        )


class MsCocoConfig(ds.BuilderConfig):
    """Builder config selecting the COCO release year and annotation task."""

    YEARS: Tuple[int, ...] = (
        2014,
        2017,
    )
    TASKS: Tuple[str, ...] = (
        "captions",
        "instances",
        "person_keypoints",
    )

    def __init__(
        self,
        year: int,
        coco_task: Union[str, Sequence[str]],
        version: Optional[Union[ds.Version, str]],
        decode_rle: bool = False,
        data_dir: Optional[str] = None,
        data_files: Optional[DataFilesDict] = None,
        description: Optional[str] = None,
    ) -> None:
        """Create a config.

        Args:
            year: COCO release year; must be one of ``YEARS``.
            coco_task: a task name from ``TASKS``, or a list/tuple of them.
            version: dataset version.
            decode_rle: if True, segmentations are decoded into 0/255 masks;
                otherwise they are kept as compressed RLE.
            data_dir: forwarded to ``ds.BuilderConfig``.
            data_files: forwarded to ``ds.BuilderConfig``.
            description: forwarded to ``ds.BuilderConfig``.

        Raises:
            AssertionError: if ``year`` or ``coco_task`` is not supported.
            ValueError: if ``coco_task`` has an unsupported type, or no
                processor exists for it.
        """
        super().__init__(
            name=self.config_name(year=year, task=coco_task),
            version=version,
            data_dir=data_dir,
            data_files=data_files,
            description=description,
        )
        self._check_year(year)
        self._check_task(coco_task)

        self._year = year
        self._task = coco_task
        self.processor = self.get_processor()
        self.decode_rle = decode_rle

    def _check_year(self, year: int) -> None:
        """Assert that ``year`` is one of ``YEARS``."""
        assert year in self.YEARS, year

    def _check_task(self, task: Union[str, Sequence[str]]) -> None:
        """Validate ``task`` against ``TASKS``.

        Raises:
            AssertionError: if ``task`` (or any element of it) is not in
                ``TASKS``, or if it is an empty sequence.
            ValueError: if ``task`` is neither a string nor a list/tuple.
        """
        if isinstance(task, str):
            assert task in self.TASKS, task
        elif isinstance(task, (list, tuple)):
            # Each element must be a known task, just like the single-string case.
            assert len(task) > 0, task
            for t in task:
                assert t in self.TASKS, task
        else:
            raise ValueError(f"Invalid task: {task}")

    @property
    def year(self) -> int:
        """Return the COCO release year."""
        return self._year

    @property
    def task(self) -> str:
        """Return the task name; multiple tasks are sorted and joined with ``-``."""
        if isinstance(self._task, str):
            return self._task
        elif isinstance(self._task, (list, tuple)):
            return "-".join(sorted(self._task))
        else:
            raise ValueError(f"Invalid task: {self._task}")

    def get_processor(self) -> MsCocoProcessor:
        """Return the processor for ``task``.

        Raises:
            ValueError: if no processor exists for the task, which includes
                every multi-task combination.
        """
        processor_classes = {
            "captions": CaptionsProcessor,
            "instances": InstancesProcessor,
            "person_keypoints": PersonKeypointsProcessor,
        }
        processor_cls = processor_classes.get(self.task)
        if processor_cls is None:
            raise ValueError(f"Invalid task: {self.task}")
        return processor_cls()

    @classmethod
    def config_name(cls, year: int, task: Union[str, Sequence[str]]) -> str:
        """Return the config name ``"{year}-{task}"``; multiple tasks are joined with ``-``."""
        if isinstance(task, str):
            return f"{year}-{task}"
        elif isinstance(task, (list, tuple)):
            task = "-".join(task)
            return f"{year}-{task}"
        else:
            raise ValueError(f"Invalid task: {task}")


def dataset_configs(year: int, version: ds.Version) -> List[MsCocoConfig]:
    """Return one single-task config per entry of ``MsCocoConfig.TASKS`` for ``year``."""
    # Multi-task configs such as ("captions", "instances") are not built:
    # MsCocoConfig.get_processor only supports a single task.
    return [
        MsCocoConfig(year=year, coco_task=task, version=version)
        for task in MsCocoConfig.TASKS
    ]


def configs_2014(version: ds.Version) -> List[MsCocoConfig]:
    """Return the configs for COCO 2014."""
    return dataset_configs(year=2014, version=version)


def configs_2017(version: ds.Version) -> List[MsCocoConfig]:
    """Return the configs for COCO 2017."""
    return dataset_configs(year=2017, version=version)


class MsCocoDataset(ds.GeneratorBasedBuilder):
    """Dataset builder for the MSCOCO validation split."""

    VERSION = ds.Version("1.0.0")
    BUILDER_CONFIG_CLASS = MsCocoConfig
    BUILDER_CONFIGS = configs_2014(version=VERSION) + configs_2017(version=VERSION)

    @property
    def year(self) -> int:
        """Return the COCO release year of the active config."""
        config: MsCocoConfig = self.config  # type: ignore
        return config.year

    @property
    def task(self) -> str:
        """Return the task name of the active config."""
        config: MsCocoConfig = self.config  # type: ignore
        return config.task

    def _info(self) -> ds.DatasetInfo:
        """Return dataset metadata and the task-specific feature spec."""
        processor: MsCocoProcessor = self.config.processor  # type: ignore
        features = processor.get_features(decode_rle=self.config.decode_rle)  # type: ignore
        return ds.DatasetInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            features=features,
        )

    def _split_generators(self, dl_manager: ds.DownloadManager):
        """Download the validation archives and return the validation split generator."""
        year_urls = _URLS[f"{self.year}"]
        # Only fetch the archives used by the active (validation) split.
        # Passing the whole ``year_urls`` mapping would also download every
        # other archive listed for the year (e.g. stuff/panoptic annotations
        # for 2017). Add entries here if other splits are re-enabled.
        urls_to_download = {
            "images": {"validation": year_urls["images"]["validation"]},
            "annotations": {
                "train_validation": year_urls["annotations"]["train_validation"],
            },
        }
        file_paths = dl_manager.download_and_extract(urls_to_download)

        imgs = file_paths["images"]  # type: ignore
        anns = file_paths["annotations"]  # type: ignore

        return [
            # ds.SplitGenerator(
            #     name=ds.Split.TRAIN,  # type: ignore
            #     gen_kwargs={
            #         "base_image_dir": imgs["train"],
            #         "base_annotation_dir": anns["train_validation"],
            #         "split": "train",
            #     },
            # ),
            ds.SplitGenerator(
                name=ds.Split.VALIDATION,  # type: ignore
                gen_kwargs={
                    "base_image_dir": imgs["validation"],
                    "base_annotation_dir": anns["train_validation"],
                    "split": "val",
                },
            ),
            # ds.SplitGenerator(
            #     name=ds.Split.TEST,  # type: ignore
            #     gen_kwargs={
            #         "base_image_dir": imgs["test"],
            #         "test_image_info_path": anns["test_image_info"],
            #         "split": "test",
            #     },
            # ),
        ]

    def _generate_train_val_examples(
        self, split: str, base_image_dir: str, base_annotation_dir: str
    ):
        """Load ``{task}_{split}{year}.json`` and yield the split's examples.

        Images are read from ``{base_image_dir}/{split}{year}``. The annotation
        file is read from ``{base_annotation_dir}/annotations``.
        """
        image_dir = os.path.join(base_image_dir, f"{split}{self.year}")

        ann_dir = os.path.join(base_annotation_dir, "annotations")
        ann_file_path = os.path.join(ann_dir, f"{self.task}_{split}{self.year}.json")

        processor: MsCocoProcessor = self.config.processor  # type: ignore

        ann_json = processor.load_annotation_json(ann_file_path=ann_file_path)

        # info = AnnotationInfo.from_dict(ann_json["info"])
        licenses = processor.load_licenses_data(license_dicts=ann_json["licenses"])
        images = processor.load_images_data(image_dicts=ann_json["images"])

        # Not every annotation file has a "categories" section; fall back to None.
        category_dicts = ann_json.get("categories")
        categories = (
            processor.load_categories_data(category_dicts=category_dicts)
            if category_dicts is not None
            else None
        )

        config: MsCocoConfig = self.config  # type: ignore
        yield from processor.generate_examples(
            annotations=processor.load_data(
                ann_dicts=ann_json["annotations"],
                images=images,
                decode_rle=config.decode_rle,
            ),
            categories=categories,
            image_dir=image_dir,
            images=images,
            licenses=licenses,
        )

    def _generate_test_examples(self, test_image_info_path: str):
        """Yield test examples (not implemented)."""
        raise NotImplementedError

    def _generate_examples(
        self,
        split: MscocoSplits,
        base_image_dir: Optional[str] = None,
        base_annotation_dir: Optional[str] = None,
        test_image_info_path: Optional[str] = None,
    ):
        """Dispatch to the test or train/val example generator based on the arguments.

        Raises:
            ValueError: if the arguments match neither generator.
        """
        if split == "test" and test_image_info_path is not None:
            yield from self._generate_test_examples(
                test_image_info_path=test_image_info_path
            )
        elif (
            split in get_args(MscocoSplits)
            and base_image_dir is not None
            and base_annotation_dir is not None
        ):
            yield from self._generate_train_val_examples(
                split=split,
                base_image_dir=base_image_dir,
                base_annotation_dir=base_annotation_dir,
            )
        else:
            raise ValueError(
                f"Invalid arguments: split = {split}, "
                f"base_image_dir = {base_image_dir}, "
                f"base_annotation_dir = {base_annotation_dir}, "
                f"test_image_info_path = {test_image_info_path}",
            )