# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data Loading Script for MSCOCO.

Adapted from https://github.com/shunk031/huggingface-datasets_MSCOCO/blob/main/MSCOCO.py to only download validation
images.
"""

import abc
import json
import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import (
    Any,
    Dict,
    Final,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    TypedDict,
    Union,
    get_args,
)

import datasets as ds
import numpy as np
from datasets.data_files import DataFilesDict
from PIL import Image
from PIL.Image import Image as PilImage
from pycocotools import mask as cocomask
from tqdm.auto import tqdm

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type aliases used throughout this script.
JsonDict = Dict[str, Any]  # a decoded JSON object
ImageId = int
AnnotationId = int
LicenseId = int
CategoryId = int
# Four floats; NOTE(review): presumably COCO's [x, y, width, height] — confirm.
Bbox = Tuple[float, float, float, float]

# Split names accepted by `MsCocoDataset._generate_examples`.
MscocoSplits = Literal["train", "val", "test"]

# Human-readable names for the keypoint flag `v`; the flag value is used as the
# index into this list (see `PersonKeypointsAnnotationData.v_keypoint_to_state`).
KEYPOINT_STATE: Final[List[str]] = ["unknown", "invisible", "visible"]


# Dataset-card metadata handed to `ds.DatasetInfo` in `MsCocoDataset._info`.
# The citation, description and homepage strings are currently empty.
_CITATION = """
"""

_DESCRIPTION = """
"""

_HOMEPAGE = """
"""

_LICENSE = "https://creativecommons.org/licenses/by/4.0/legalcode"

# Download locations, keyed by release year (as a string), then by
# "images" / "annotations", then by archive name. `_split_generators` passes
# the whole per-year sub-dictionary to `dl_manager.download_and_extract`.
# The train/test/unlabeled image archives for 2014 and 2017 are commented out,
# so only the validation images are downloaded for those years.
_URLS = {
    "2014": {
        "images": {
            # "train": "http://images.cocodataset.org/zips/train2014.zip",
            "validation": "http://images.cocodataset.org/zips/val2014.zip",
            # "test": "http://images.cocodataset.org/zips/test2014.zip",
        },
        "annotations": {
            "train_validation": "http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
            "test_image_info": "http://images.cocodataset.org/annotations/image_info_test2014.zip",
        },
    },
    "2015": {
        "images": {
            "test": "http://images.cocodataset.org/zips/test2015.zip",
        },
        "annotations": {
            "test_image_info": "http://images.cocodataset.org/annotations/image_info_test2015.zip",
        },
    },
    "2017": {
        "images": {
            # "train": "http://images.cocodataset.org/zips/train2017.zip",
            "validation": "http://images.cocodataset.org/zips/val2017.zip",
            # "test": "http://images.cocodataset.org/zips/test2017.zip",
            # "unlabeled": "http://images.cocodataset.org/zips/unlabeled2017.zip",
        },
        "annotations": {
            "train_validation": "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
            "stuff_train_validation": "http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip",
            "panoptic_train_validation": (
                "http://images.cocodataset.org/annotations/"
                "panoptic_annotations_trainval2017.zip"
            ),
            "test_image_info": "http://images.cocodataset.org/annotations/image_info_test2017.zip",
            "unlabeled": "http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip",
        },
    },
}

# Object category names. Used as the `ClassLabel` names of the
# `category.name` feature in `InstancesProcessor.get_features_instance_dict`.
CATEGORIES: Final[List[str]] = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]

# Super-category names. Used as the `ClassLabel` names of the
# `category.supercategory` feature in
# `InstancesProcessor.get_features_instance_dict`.
SUPER_CATEGORIES: Final[List[str]] = [
    "person",
    "vehicle",
    "outdoor",
    "animal",
    "accessory",
    "sports",
    "kitchen",
    "food",
    "furniture",
    "electronic",
    "appliance",
    "indoor",
]


@dataclass
class AnnotationInfo(object):
    """Metadata record describing an annotation file (its ``info`` entry)."""

    description: str
    url: str
    version: str
    year: str
    contributor: str
    date_created: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "AnnotationInfo":
        """Build an instance by passing every key of *json_dict* as a keyword.

        The keys must therefore match the field names exactly; a missing or
        unexpected key raises ``TypeError`` from the generated ``__init__``.
        """
        info = cls(**json_dict)
        return info


@dataclass
class LicenseData(object):
    """One license entry of an annotation file."""

    url: str
    license_id: LicenseId
    name: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "LicenseData":
        """Create a record from a JSON license object.

        Reads the ``"id"``, ``"url"`` and ``"name"`` keys (in that order); the
        JSON ``"id"`` is stored as ``license_id``. Other keys are ignored.
        """
        license_id, url, name = (json_dict[key] for key in ("id", "url", "name"))
        return cls(url=url, license_id=license_id, name=name)


@dataclass
class ImageData(object):
    """Metadata of one image entry of an annotation file."""

    image_id: ImageId
    license_id: LicenseId
    file_name: str
    coco_url: str
    height: int
    width: int
    date_captured: str
    flickr_url: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "ImageData":
        """Create a record from a JSON image object.

        The JSON ``"id"`` and ``"license"`` keys are renamed to ``image_id``
        and ``license_id``; every other field keeps its JSON key. Keys not
        listed below are ignored.
        """
        # (field name, JSON key) pairs, read in this order.
        field_to_key = (
            ("image_id", "id"),
            ("license_id", "license"),
            ("file_name", "file_name"),
            ("coco_url", "coco_url"),
            ("height", "height"),
            ("width", "width"),
            ("date_captured", "date_captured"),
            ("flickr_url", "flickr_url"),
        )
        return cls(**{field: json_dict[key] for field, key in field_to_key})

    @property
    def shape(self) -> Tuple[int, int]:
        """Return ``(height, width)``."""
        return self.height, self.width


@dataclass
class CategoryData(object):
    """One category entry of an annotation file."""

    category_id: int
    name: str
    supercategory: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "CategoryData":
        """Create a record from a JSON category object.

        Reads ``"id"`` (stored as ``category_id``), ``"name"`` and
        ``"supercategory"``, in that order.
        """
        category_id = json_dict["id"]
        name = json_dict["name"]
        supercategory = json_dict["supercategory"]
        return cls(category_id, name, supercategory)


@dataclass
class AnnotationData(object):
    """Fields shared by every annotation record: its own id and its image's id."""

    annotation_id: AnnotationId
    image_id: ImageId


@dataclass
class CaptionsAnnotationData(AnnotationData):
    """A caption annotation attached to one image."""

    caption: str

    @classmethod
    def from_dict(cls, json_dict: JsonDict) -> "CaptionsAnnotationData":
        """Create a record from a JSON caption annotation.

        Reads ``"id"`` (stored as ``annotation_id``), ``"image_id"`` and
        ``"caption"``, in that order.
        """
        annotation_id = json_dict["id"]
        image_id = json_dict["image_id"]
        caption = json_dict["caption"]
        return cls(annotation_id=annotation_id, image_id=image_id, caption=caption)


class UncompressedRLE(TypedDict):
    """Run-length encoding whose ``counts`` is a plain list of integers."""

    counts: List[int]
    size: Tuple[int, int]


class CompressedRLE(TypedDict):
    """Run-length encoding whose ``counts`` is a byte string.

    This is the declared return type of `InstancesAnnotationData.compress_rle`.
    """

    counts: bytes
    size: Tuple[int, int]


@dataclass
class InstancesAnnotationData(AnnotationData):
    """An object-instance annotation: segmentation, area, crowd flag, box, category.

    ``segmentation`` holds either a decoded mask array or a compressed RLE,
    depending on the ``decode_rle`` flag given to `from_dict`.
    """

    segmentation: Union[np.ndarray, CompressedRLE]
    area: float
    iscrowd: bool
    bbox: Tuple[float, float, float, float]
    category_id: int

    @classmethod
    def compress_rle(
        cls,
        segmentation: Union[List[List[float]], UncompressedRLE],
        iscrowd: bool,
        height: int,
        width: int,
    ) -> CompressedRLE:
        """Encode *segmentation* with ``cocomask.frPyObjects``.

        For crowd annotations the encoder's result is returned as is; otherwise
        the result is additionally combined with ``cocomask.merge``.
        """
        encoded = cocomask.frPyObjects(segmentation, h=height, w=width)
        if iscrowd:
            return encoded  # type: ignore
        return cocomask.merge(encoded)  # type: ignore

    @classmethod
    def rle_segmentation_to_binary_mask(
        cls, segmentation, iscrowd: bool, height: int, width: int
    ) -> np.ndarray:
        """Compress *segmentation* and decode it with ``cocomask.decode``."""
        compressed = cls.compress_rle(
            segmentation=segmentation,
            iscrowd=iscrowd,
            height=height,
            width=width,
        )
        return cocomask.decode(compressed)  # type: ignore

    @classmethod
    def rle_segmentation_to_mask(
        cls,
        segmentation: Union[List[List[float]], UncompressedRLE],
        iscrowd: bool,
        height: int,
        width: int,
    ) -> np.ndarray:
        """Return the decoded binary mask multiplied by 255."""
        mask = cls.rle_segmentation_to_binary_mask(
            segmentation=segmentation,
            iscrowd=iscrowd,
            height=height,
            width=width,
        )
        return mask * 255

    @classmethod
    def from_dict(
        cls,
        json_dict: JsonDict,
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
    ) -> "InstancesAnnotationData":
        """Create a record from a JSON instance annotation.

        Args:
            json_dict: The annotation object.
            images: Image records by id; the annotation's image supplies the
                height and width used to encode the segmentation.
            decode_rle: When true, store a decoded mask (values scaled by
                255); otherwise store the compressed RLE.
        """
        segmentation = json_dict["segmentation"]
        image_id = json_dict["image_id"]
        image_data = images[image_id]
        iscrowd = bool(json_dict["iscrowd"])

        # Pick the converter once; both accept the same keyword arguments.
        convert = cls.rle_segmentation_to_mask if decode_rle else cls.compress_rle
        segmentation_mask = convert(
            segmentation=segmentation,
            iscrowd=iscrowd,
            height=image_data.height,
            width=image_data.width,
        )

        return cls(
            # AnnotationData fields
            annotation_id=json_dict["id"],
            image_id=image_id,
            # InstancesAnnotationData fields
            segmentation=segmentation_mask,  # type: ignore
            area=json_dict["area"],
            iscrowd=iscrowd,
            bbox=json_dict["bbox"],
            category_id=json_dict["category_id"],
        )


@dataclass
class PersonKeypoint(object):
    """A single keypoint: coordinates, raw flag ``v`` and its textual state."""

    x: int
    y: int
    v: int
    state: str  # KEYPOINT_STATE[v], as filled in by get_person_keypoints


@dataclass
class PersonKeypointsAnnotationData(InstancesAnnotationData):
    """An instance annotation extended with person keypoints."""

    num_keypoints: int
    keypoints: List[PersonKeypoint]

    @classmethod
    def v_keypoint_to_state(cls, keypoint_v: int) -> str:
        """Look up the state name for a keypoint flag in ``KEYPOINT_STATE``."""
        return KEYPOINT_STATE[keypoint_v]

    @classmethod
    def get_person_keypoints(
        cls, flatten_keypoints: List[int], num_keypoints: int
    ) -> List[PersonKeypoint]:
        """Turn a flat ``[x, y, v, x, y, v, ...]`` list into keypoint records.

        Asserts that the list splits into equally long x/y/v columns and that
        the number of keypoints whose state is not ``"unknown"`` equals
        *num_keypoints*.
        """
        xs = flatten_keypoints[0::3]
        ys = flatten_keypoints[1::3]
        vs = flatten_keypoints[2::3]
        assert len(xs) == len(ys) == len(vs)

        keypoints = []
        for x, y, v in zip(xs, ys, vs):
            state = cls.v_keypoint_to_state(v)
            keypoints.append(PersonKeypoint(x=x, y=y, v=v, state=state))

        labelled = [kp for kp in keypoints if kp.state != "unknown"]
        assert len(labelled) == num_keypoints
        return keypoints

    @classmethod
    def from_dict(
        cls,
        json_dict: JsonDict,
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
    ) -> "PersonKeypointsAnnotationData":
        """Create a record from a JSON person-keypoints annotation.

        Args:
            json_dict: The annotation object.
            images: Image records by id; the annotation's image supplies the
                height and width used to encode the segmentation.
            decode_rle: When true, store a decoded mask (values scaled by
                255); otherwise store the compressed RLE.
        """
        segmentation = json_dict["segmentation"]
        image_id = json_dict["image_id"]
        image_data = images[image_id]
        iscrowd = bool(json_dict["iscrowd"])

        # Pick the converter once; both accept the same keyword arguments.
        convert = cls.rle_segmentation_to_mask if decode_rle else cls.compress_rle
        segmentation_mask = convert(
            segmentation=segmentation,
            iscrowd=iscrowd,
            height=image_data.height,
            width=image_data.width,
        )

        flatten_keypoints = json_dict["keypoints"]
        num_keypoints = json_dict["num_keypoints"]
        keypoints = cls.get_person_keypoints(flatten_keypoints, num_keypoints)

        return cls(
            # AnnotationData fields
            annotation_id=json_dict["id"],
            image_id=image_id,
            # InstancesAnnotationData fields
            segmentation=segmentation_mask,  # type: ignore
            area=json_dict["area"],
            iscrowd=iscrowd,
            bbox=json_dict["bbox"],
            category_id=json_dict["category_id"],
            # PersonKeypointsAnnotationData fields
            num_keypoints=num_keypoints,
            keypoints=keypoints,
        )


class LicenseDict(TypedDict):
    """Shape of the ``license`` entry embedded in each example."""

    license_id: LicenseId
    name: str
    url: str


class BaseExample(TypedDict):
    """Keys common to every generated example, whatever the task."""

    image_id: ImageId
    image: PilImage
    file_name: str
    coco_url: str
    height: int
    width: int
    date_captured: str
    flickr_url: str
    license_id: LicenseId
    license: LicenseDict


class CaptionAnnotationDict(TypedDict):
    """Shape of one caption annotation inside a `CaptionExample`."""

    annotation_id: AnnotationId
    caption: str


class CaptionExample(BaseExample):
    """Example yielded by `CaptionsProcessor.generate_examples`."""

    annotations: List[CaptionAnnotationDict]


class CategoryDict(TypedDict):
    """Shape of the ``category`` entry embedded in an instance annotation."""

    category_id: CategoryId
    name: str
    supercategory: str


class InstanceAnnotationDict(TypedDict):
    """Shape of one instance annotation inside an `InstanceExample`."""

    annotation_id: AnnotationId
    area: float
    bbox: Bbox
    image_id: ImageId
    category_id: CategoryId
    category: CategoryDict
    iscrowd: bool
    # NOTE(review): `InstancesAnnotationData.from_dict` stores a compressed RLE
    # here instead of an array when `decode_rle` is false — confirm whether
    # this annotation should be widened.
    segmentation: np.ndarray


class InstanceExample(BaseExample):
    """Example yielded by `InstancesProcessor.generate_examples`."""

    annotations: List[InstanceAnnotationDict]


class KeypointDict(TypedDict):
    """Dictionary form of a `PersonKeypoint`."""

    x: int
    y: int
    v: int
    state: str


class PersonKeypointAnnotationDict(InstanceAnnotationDict):
    """Instance annotation dictionary extended with keypoint information."""

    num_keypoints: int
    keypoints: List[KeypointDict]


class PersonKeypointExample(BaseExample):
    """Example yielded by `PersonKeypointsProcessor.generate_examples`."""

    annotations: List[PersonKeypointAnnotationDict]


class MsCocoProcessor(object, metaclass=abc.ABCMeta):
    """Abstract base class for the task-specific processors.

    It provides the loaders shared by all tasks (annotation JSON, licenses,
    images, categories) and the feature entries common to every example.
    Subclasses implement `get_features`, `load_data` and `generate_examples`.
    """

    def load_image(self, image_path: str) -> PilImage:
        """Open the image at *image_path* with PIL and return it."""
        return Image.open(image_path)

    def load_annotation_json(self, ann_file_path: str) -> JsonDict:
        """Read and parse the annotation JSON file at *ann_file_path*.

        The file is decoded as UTF-8 explicitly. Relying on the platform
        default encoding can fail or mis-decode non-ASCII text on systems
        whose locale encoding is not UTF-8.
        """
        logger.info(f"Load annotation json from {ann_file_path}")
        with open(ann_file_path, "r", encoding="utf-8") as rf:
            return json.load(rf)

    def load_licenses_data(
        self, license_dicts: List[JsonDict]
    ) -> Dict[LicenseId, LicenseData]:
        """Parse license objects into `LicenseData`, keyed by license id.

        If an id occurs more than once, the last entry wins.
        """
        licenses = {}
        for license_dict in license_dicts:
            license_data = LicenseData.from_dict(license_dict)
            licenses[license_data.license_id] = license_data
        return licenses

    def load_images_data(
        self,
        image_dicts: List[JsonDict],
        tqdm_desc: str = "Load images",
    ) -> Dict[ImageId, ImageData]:
        """Parse image objects into `ImageData`, keyed by image id.

        Args:
            image_dicts: The image objects to parse.
            tqdm_desc: Label shown on the progress bar.
        """
        images = {}
        for image_dict in tqdm(image_dicts, desc=tqdm_desc):
            image_data = ImageData.from_dict(image_dict)
            images[image_data.image_id] = image_data
        return images

    def load_categories_data(
        self,
        category_dicts: List[JsonDict],
        tqdm_desc: str = "Load categories",
    ) -> Dict[CategoryId, CategoryData]:
        """Parse category objects into `CategoryData`, keyed by category id.

        Args:
            category_dicts: The category objects to parse.
            tqdm_desc: Label shown on the progress bar.
        """
        categories = {}
        for category_dict in tqdm(category_dicts, desc=tqdm_desc):
            category_data = CategoryData.from_dict(category_dict)
            categories[category_data.category_id] = category_data
        return categories

    def get_features_base_dict(self) -> Dict[str, Any]:
        """Return the feature entries shared by every task's examples."""
        return {
            "image_id": ds.Value("int64"),
            "image": ds.Image(),
            "file_name": ds.Value("string"),
            "coco_url": ds.Value("string"),
            "height": ds.Value("int32"),
            "width": ds.Value("int32"),
            "date_captured": ds.Value("string"),
            "flickr_url": ds.Value("string"),
            "license_id": ds.Value("int32"),
            "license": {
                "url": ds.Value("string"),
                "license_id": ds.Value("int8"),
                "name": ds.Value("string"),
            },
        }

    @abc.abstractmethod
    def get_features(self, *args, **kwargs) -> ds.Features:
        """Return the full feature schema for the task (subclass hook)."""
        raise NotImplementedError

    @abc.abstractmethod
    def load_data(self, ann_dicts: List[JsonDict], tqdm_desc: str = "", **kwargs):
        """Parse the raw annotation objects for the task (subclass hook)."""
        assert tqdm_desc != "", "tqdm_desc must be provided."
        raise NotImplementedError

    @abc.abstractmethod
    def generate_examples(
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[CaptionsAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        **kwargs,
    ):
        """Yield ``(index, example)`` pairs for the task (subclass hook)."""
        raise NotImplementedError


class CaptionsProcessor(MsCocoProcessor):
    """Processor for the ``captions`` task."""

    def get_features(self, *args, **kwargs) -> ds.Features:
        """Return the base features plus a sequence of caption annotations.

        Positional and keyword arguments are accepted but not used.
        """
        caption_annotations = ds.Sequence(
            {
                "annotation_id": ds.Value("int64"),
                "image_id": ds.Value("int64"),
                "caption": ds.Value("string"),
            }
        )
        features_dict = self.get_features_base_dict()
        features_dict["annotations"] = caption_annotations
        return ds.Features(features_dict)

    def load_data(
        self,
        ann_dicts: List[JsonDict],
        tqdm_desc: str = "Load captions data",
        **kwargs,
    ) -> Dict[ImageId, List[CaptionsAnnotationData]]:
        """Parse caption annotations and group them by image id.

        Extra keyword arguments are accepted but not used.
        """
        grouped = defaultdict(list)
        for raw_annotation in tqdm(ann_dicts, desc=tqdm_desc):
            parsed = CaptionsAnnotationData.from_dict(raw_annotation)
            grouped[parsed.image_id].append(parsed)
        return grouped

    def generate_examples(
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[CaptionsAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        **kwargs,
    ) -> Iterator[Tuple[int, CaptionExample]]:
        """Yield ``(index, example)`` for every image, in ``images`` order.

        Asserts that each image has at least one caption annotation. Extra
        keyword arguments are accepted but not used.
        """
        for idx, (image_id, image_data) in enumerate(images.items()):
            captions = annotations[image_id]
            assert len(captions) > 0

            image_path = os.path.join(image_dir, image_data.file_name)
            image = self.load_image(image_path=image_path)

            example = asdict(image_data)
            example["image"] = image
            example["license"] = asdict(licenses[image_data.license_id])
            example["annotations"] = [asdict(caption) for caption in captions]

            yield idx, example  # type: ignore


class InstancesProcessor(MsCocoProcessor):
    """Processor for the ``instances`` task."""

    def get_features_instance_dict(self, decode_rle: bool):
        """Return the feature entries of a single instance annotation.

        The ``segmentation`` entry is an image feature when *decode_rle* is
        true, and a ``counts``/``size`` structure otherwise.
        """
        if decode_rle:
            segmentation_feature = ds.Image()
        else:
            segmentation_feature = {
                "counts": ds.Sequence(ds.Value("int64")),
                "size": ds.Sequence(ds.Value("int32")),
            }

        category_feature = {
            "category_id": ds.Value("int32"),
            "name": ds.ClassLabel(
                num_classes=len(CATEGORIES),
                names=CATEGORIES,
            ),
            "supercategory": ds.ClassLabel(
                num_classes=len(SUPER_CATEGORIES),
                names=SUPER_CATEGORIES,
            ),
        }
        return {
            "annotation_id": ds.Value("int64"),
            "image_id": ds.Value("int64"),
            "segmentation": segmentation_feature,
            "area": ds.Value("float32"),
            "iscrowd": ds.Value("bool"),
            "bbox": ds.Sequence(ds.Value("float32"), length=4),
            "category_id": ds.Value("int32"),
            "category": category_feature,
        }

    def get_features(self, decode_rle: bool) -> ds.Features:
        """Return the base features plus a sequence of instance annotations."""
        instance_features = self.get_features_instance_dict(decode_rle=decode_rle)
        features_dict = self.get_features_base_dict()
        features_dict["annotations"] = ds.Sequence(instance_features)
        return ds.Features(features_dict)

    def load_data(  # type: ignore[override]
        self,
        ann_dicts: List[JsonDict],
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
        tqdm_desc: str = "Load instances data",
    ) -> Dict[ImageId, List[InstancesAnnotationData]]:
        """Parse instance annotations and group them by image id.

        Annotations are processed in ascending ``image_id`` order.
        """
        ordered = sorted(ann_dicts, key=lambda d: d["image_id"])

        grouped = defaultdict(list)
        for raw_annotation in tqdm(ordered, desc=tqdm_desc):
            parsed = InstancesAnnotationData.from_dict(
                raw_annotation, images=images, decode_rle=decode_rle
            )
            grouped[parsed.image_id].append(parsed)
        return grouped

    def generate_examples(  # type: ignore[override]
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[InstancesAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        categories: Dict[CategoryId, CategoryData],
    ) -> Iterator[Tuple[int, InstanceExample]]:
        """Yield ``(index, example)`` for every image that has annotations.

        Images without annotations are skipped with a warning. The index is
        the image's position in ``images``, so skipped images leave gaps.
        """
        for idx, (image_id, image_data) in enumerate(images.items()):
            instances = annotations[image_id]
            if not instances:
                logger.warning(f"No annotation found for image id: {image_id}.")
                continue

            image_path = os.path.join(image_dir, image_data.file_name)
            image = self.load_image(image_path=image_path)

            example = asdict(image_data)
            example["image"] = image
            example["license"] = asdict(licenses[image_data.license_id])

            annotation_dicts = []
            for instance in instances:
                instance_dict = asdict(instance)
                # Embed the full category record next to the bare id.
                instance_dict["category"] = asdict(categories[instance.category_id])
                annotation_dicts.append(instance_dict)
            example["annotations"] = annotation_dicts

            yield idx, example  # type: ignore


class PersonKeypointsProcessor(InstancesProcessor):
    """Processor for the ``person_keypoints`` task."""

    def get_features(self, decode_rle: bool) -> ds.Features:
        """Return the instance features extended with keypoint entries."""
        keypoint_feature = ds.Sequence(
            {
                "state": ds.Value("string"),
                "x": ds.Value("int32"),
                "y": ds.Value("int32"),
                "v": ds.Value("int32"),
            }
        )
        annotation_features = self.get_features_instance_dict(decode_rle=decode_rle)
        annotation_features["keypoints"] = keypoint_feature
        annotation_features["num_keypoints"] = ds.Value("int32")

        features_dict = self.get_features_base_dict()
        features_dict["annotations"] = ds.Sequence(annotation_features)
        return ds.Features(features_dict)

    def load_data(  # type: ignore[override]
        self,
        ann_dicts: List[JsonDict],
        images: Dict[ImageId, ImageData],
        decode_rle: bool,
        tqdm_desc: str = "Load person keypoints data",
    ) -> Dict[ImageId, List[PersonKeypointsAnnotationData]]:
        """Parse keypoint annotations and group them by image id.

        Annotations are processed in ascending ``image_id`` order.
        """
        ordered = sorted(ann_dicts, key=lambda d: d["image_id"])

        grouped = defaultdict(list)
        for raw_annotation in tqdm(ordered, desc=tqdm_desc):
            parsed = PersonKeypointsAnnotationData.from_dict(
                raw_annotation, images=images, decode_rle=decode_rle
            )
            grouped[parsed.image_id].append(parsed)
        return grouped

    def generate_examples(  # type: ignore[override]
        self,
        image_dir: str,
        images: Dict[ImageId, ImageData],
        annotations: Dict[ImageId, List[PersonKeypointsAnnotationData]],
        licenses: Dict[LicenseId, LicenseData],
        categories: Dict[CategoryId, CategoryData],
    ) -> Iterator[Tuple[int, PersonKeypointExample]]:
        """Yield ``(index, example)`` for every image that has annotations.

        Images without keypoint annotations are skipped silently. The index
        is the image's position in ``images``, so skipped images leave gaps.
        """
        for idx, (image_id, image_data) in enumerate(images.items()):
            person_annotations = annotations[image_id]
            if not person_annotations:
                # An image without persons carries no keypoint annotations.
                continue

            image_path = os.path.join(image_dir, image_data.file_name)
            image = self.load_image(image_path=image_path)

            example = asdict(image_data)
            example["image"] = image
            example["license"] = asdict(licenses[image_data.license_id])

            annotation_dicts = []
            for person in person_annotations:
                person_dict = asdict(person)
                # Embed the full category record next to the bare id.
                person_dict["category"] = asdict(categories[person.category_id])
                annotation_dicts.append(person_dict)
            example["annotations"] = annotation_dicts

            yield idx, example  # type: ignore


class MsCocoConfig(ds.BuilderConfig):
    """Builder configuration selecting a release year and a task.

    Attributes:
        processor: The task-specific `MsCocoProcessor` chosen by `get_processor`.
        decode_rle: Forwarded to the processor when building features and
            loading annotations.
    """

    # Values accepted for ``year`` and for (each element of) ``coco_task``.
    YEARS: Tuple[int, ...] = (
        2014,
        2017,
    )
    TASKS: Tuple[str, ...] = (
        "captions",
        "instances",
        "person_keypoints",
    )

    def __init__(
        self,
        year: int,
        coco_task: Union[str, Sequence[str]],
        version: Optional[Union[ds.Version, str]],
        decode_rle: bool = False,
        data_dir: Optional[str] = None,
        data_files: Optional[DataFilesDict] = None,
        description: Optional[str] = None,
    ) -> None:
        """Validate the year and task, then select the matching processor.

        Raises:
            AssertionError: If *year* or a task name is not supported.
            ValueError: If *coco_task* is neither a string nor a list/tuple,
                or if no processor exists for the resulting task name.
        """
        super().__init__(
            name=self.config_name(year=year, task=coco_task),
            version=version,
            data_dir=data_dir,
            data_files=data_files,
            description=description,
        )
        self._check_year(year)
        self._check_task(coco_task)

        self._year = year
        self._task = coco_task
        self.processor = self.get_processor()
        self.decode_rle = decode_rle

    def _check_year(self, year: int) -> None:
        """Assert that *year* is one of `YEARS`."""
        assert year in self.YEARS, year

    def _check_task(self, task: Union[str, Sequence[str]]) -> None:
        """Assert that *task* (or each of its elements) is one of `TASKS`.

        Raises:
            ValueError: If *task* is neither a string nor a list/tuple.
        """
        if isinstance(task, str):
            assert task in self.TASKS, task
        elif isinstance(task, (list, tuple)):
            for t in task:
                # Check membership, not truthiness: an unknown task name must
                # be rejected just like it is for a single-string task.
                assert t in self.TASKS, task
        else:
            raise ValueError(f"Invalid task: {task}")

    @property
    def year(self) -> int:
        """The release year given at construction."""
        return self._year

    @property
    def task(self) -> str:
        """The task name; several tasks are sorted and joined with ``-``."""
        if isinstance(self._task, str):
            return self._task
        elif isinstance(self._task, (list, tuple)):
            return "-".join(sorted(self._task))
        else:
            raise ValueError(f"Invalid task: {self._task}")

    def get_processor(self) -> MsCocoProcessor:
        """Return a new processor for `task`.

        Raises:
            ValueError: If `task` is not one of the single task names handled
                here (this includes joined multi-task names).
        """
        if self.task == "captions":
            return CaptionsProcessor()
        elif self.task == "instances":
            return InstancesProcessor()
        elif self.task == "person_keypoints":
            return PersonKeypointsProcessor()
        else:
            raise ValueError(f"Invalid task: {self.task}")

    @classmethod
    def config_name(cls, year: int, task: Union[str, Sequence[str]]) -> str:
        """Return ``"<year>-<task>"``; several tasks are joined with ``-``.

        Unlike the `task` property, the tasks are joined in the order given.

        Raises:
            ValueError: If *task* is neither a string nor a list/tuple.
        """
        if isinstance(task, str):
            return f"{year}-{task}"
        elif isinstance(task, (list, tuple)):
            task = "-".join(task)
            return f"{year}-{task}"
        else:
            raise ValueError(f"Invalid task: {task}")


def dataset_configs(year: int, version: ds.Version) -> List[MsCocoConfig]:
    """Return one single-task config for *year* per supported task.

    The configs are ordered captions, instances, person_keypoints.
    """
    single_tasks = ("captions", "instances", "person_keypoints")
    # Multi-task combinations such as ("captions", "instances") or
    # ("captions", "person_keypoints") are not registered.
    return [
        MsCocoConfig(year=year, coco_task=task_name, version=version)
        for task_name in single_tasks
    ]


def configs_2014(version: ds.Version) -> List[MsCocoConfig]:
    """Return the configs for the 2014 release."""
    configs = dataset_configs(version=version, year=2014)
    return configs


def configs_2017(version: ds.Version) -> List[MsCocoConfig]:
    """Return the configs for the 2017 release."""
    configs = dataset_configs(version=version, year=2017)
    return configs


class MsCocoDataset(ds.GeneratorBasedBuilder):
    """Dataset builder that produces the validation split only."""

    VERSION = ds.Version("1.0.0")
    BUILDER_CONFIG_CLASS = MsCocoConfig
    BUILDER_CONFIGS = configs_2014(version=VERSION) + configs_2017(version=VERSION)

    @property
    def year(self) -> int:
        """Release year of the active config."""
        active_config: MsCocoConfig = self.config  # type: ignore
        return active_config.year

    @property
    def task(self) -> str:
        """Task name of the active config."""
        active_config: MsCocoConfig = self.config  # type: ignore
        return active_config.task

    def _info(self) -> ds.DatasetInfo:
        """Describe the dataset using the active processor's feature schema."""
        processor: MsCocoProcessor = self.config.processor  # type: ignore
        return ds.DatasetInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            features=processor.get_features(decode_rle=self.config.decode_rle),  # type: ignore
        )

    def _split_generators(self, dl_manager: ds.DownloadManager):
        """Download the archives for the configured year; return the validation split.

        Train and test generators are not produced: their image archives are
        commented out in ``_URLS``.
        """
        downloaded = dl_manager.download_and_extract(_URLS[f"{self.year}"])

        image_dirs = downloaded["images"]  # type: ignore
        annotation_dirs = downloaded["annotations"]  # type: ignore

        validation = ds.SplitGenerator(
            name=ds.Split.VALIDATION,  # type: ignore
            gen_kwargs={
                "base_image_dir": image_dirs["validation"],
                "base_annotation_dir": annotation_dirs["train_validation"],
                "split": "val",
            },
        )
        return [validation]

    def _generate_train_val_examples(
        self, split: str, base_image_dir: str, base_annotation_dir: str
    ):
        """Yield the examples of a train/val split.

        Images are read from ``<base_image_dir>/<split><year>`` and annotations
        from ``<base_annotation_dir>/annotations/<task>_<split><year>.json``.
        """
        split_tag = f"{split}{self.year}"
        image_dir = os.path.join(base_image_dir, split_tag)
        ann_file_path = os.path.join(
            base_annotation_dir, "annotations", f"{self.task}_{split_tag}.json"
        )

        processor: MsCocoProcessor = self.config.processor  # type: ignore
        ann_json = processor.load_annotation_json(ann_file_path=ann_file_path)

        # The "info" entry of the annotation file is not used.
        licenses = processor.load_licenses_data(license_dicts=ann_json["licenses"])
        images = processor.load_images_data(image_dicts=ann_json["images"])

        # Not every annotation file has a "categories" entry.
        category_dicts = ann_json.get("categories")
        if category_dicts is None:
            categories = None
        else:
            categories = processor.load_categories_data(category_dicts=category_dicts)

        config: MsCocoConfig = self.config  # type: ignore
        annotations = processor.load_data(
            ann_dicts=ann_json["annotations"],
            images=images,
            decode_rle=config.decode_rle,
        )
        yield from processor.generate_examples(
            annotations=annotations,
            categories=categories,
            image_dir=image_dir,
            images=images,
            licenses=licenses,
        )

    def _generate_test_examples(self, test_image_info_path: str):
        """Test-split generation is not implemented."""
        raise NotImplementedError

    def _generate_examples(
        self,
        split: MscocoSplits,
        base_image_dir: Optional[str] = None,
        base_annotation_dir: Optional[str] = None,
        test_image_info_path: Optional[str] = None,
    ):
        """Dispatch to the test or train/val generator based on the arguments.

        Raises:
            ValueError: If the arguments match neither the test case (split
                ``"test"`` with an image-info path) nor the train/val case (a
                known split with both base directories).
        """
        if split == "test" and test_image_info_path is not None:
            yield from self._generate_test_examples(
                test_image_info_path=test_image_info_path
            )
            return

        has_base_dirs = base_image_dir is not None and base_annotation_dir is not None
        if split not in get_args(MscocoSplits) or not has_base_dirs:
            raise ValueError(
                f"Invalid arguments: split = {split}, "
                f"base_image_dir = {base_image_dir}, "
                f"base_annotation_dir = {base_annotation_dir}, "
                f"test_image_info_path = {test_image_info_path}",
            )

        yield from self._generate_train_val_examples(
            split=split,
            base_image_dir=base_image_dir,
            base_annotation_dir=base_annotation_dir,
        )
