vision/m4/sourcing/pmd/loader_builder.py (276 lines of code) (raw):

import json
from abc import ABC, abstractmethod
from functools import partial
from multiprocessing import get_context
from pathlib import Path
from typing import Dict, Iterator, List

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

from m4.sourcing.pmd import _FEATURES, get_m4_cache_dir
from m4.sourcing.pmd.helpers import (
    M4HTTPClient,
    PickableMediaDownloadGenerator,
    collapse_columns_to_meta,
    fetch_single_image,
)


# ---- Base loader builder -----
class BaseLoaderBuilder(ABC):
    def __init__(self, split, num_proc: int, batch_size: int = 1000):
        self.split = split
        self.num_proc = num_proc
        self.batch_size = batch_size

    @abstractmethod
    def _load_dataset(self) -> Dataset:
        raise NotImplementedError()

    @abstractmethod
    def update_before_collapsing_meta(self, batch: Dict) -> Dict:
        raise NotImplementedError

    @property
    @abstractmethod
    def _DATASETS_NAME(self) -> str:
        raise NotImplementedError

    def _normalise(self, batch: Dict) -> Dict:
        """
        Create the `text` field, typically from the field `caption`, and remove the `caption` column.
        Remove all the unnecessary columns and put them into a json dict (`meta` column).
        """
        # `datasets.map` requires functions to be pure, which is not the case here
        # https://github.com/huggingface/datasets/pull/4197#issue-1211342558
        batch = batch.copy()

        batch = self.update_before_collapsing_meta(batch)

        # Collapse columns to meta
        batch = collapse_columns_to_meta(
            batch, columns_to_collapse=self.get_batch_metadata_columns(batch), meta_column_name="meta"
        )

        # add `source`
        batch["source"] = [self._DATASETS_NAME for _ in batch["text"]]

        return batch

    @staticmethod
    def get_batch_metadata_columns(batch: Dict) -> List[str]:
        return list(set(batch.keys()) - {"image", "text"})

    def build(self):
        dset = self._load_dataset()
        if isinstance(dset, Dataset):
            dset = dset.map(
                self._normalise,
                batched=True,
                remove_columns=dset.column_names,
                features=_FEATURES,
                num_proc=self.num_proc,
                batch_size=self.batch_size,
            )
            # Make sure the features match the expected API.
            assert dset.features == _FEATURES, f"Got: {dset.features}, expected: {_FEATURES}"
        elif isinstance(dset, DatasetDict):
            new_dset = DatasetDict()
            for split, dset_split in dset.items():
                new_dset[split] = dset_split.map(
                    self._normalise,
                    batched=True,
                    remove_columns=dset_split.column_names,
                    features=_FEATURES,
                    num_proc=self.num_proc,
                    batch_size=self.batch_size,
                )
                # Make sure the features match the expected API.
                assert new_dset[split].features == _FEATURES, f"Got: {new_dset[split].features}, expected: {_FEATURES}"
            dset = new_dset
        return dset


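# Illustrative sketch (not part of the original module): a concrete builder only
# needs to pin `_DATASETS_NAME`, return its raw `datasets` object from
# `_load_dataset`, and move its caption column into `text` in
# `update_before_collapsing_meta`; `build()` then maps `_normalise` over the
# dataset and checks the result against `_FEATURES`. The name `my_captions`
# below is hypothetical, mirroring the Localized Narratives builders further down.
#
#   class MyCaptionsLoaderBuilder(BaseLoaderBuilder):
#       _DATASETS_NAME = "my_captions"  # hypothetical dataset name
#
#       def _load_dataset(self) -> Dataset:
#           return load_dataset("my_captions", split=self.split)
#
#       def update_before_collapsing_meta(self, batch: Dict) -> Dict:
#           # move `caption` to `text`
#           batch["text"] = batch["caption"]
#           del batch["caption"]
#           return batch

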
class DownloadImageLoaderBuilder(ABC):
    def __init__(self, split, num_proc: int, num_threads_per_proc: int, **http_client_kwargs):
        assert "cache_dir" not in http_client_kwargs
        self.http_client = M4HTTPClient(
            cache_dir=get_m4_cache_dir() / self._DATASETS_NAME / "downloaded_images",
            **http_client_kwargs,
            user_agent="Googlebot-Image/1.0",
        )
        self.num_threads_per_proc = num_threads_per_proc
        self.split = split
        self.num_proc = num_proc

    @abstractmethod
    def _load_dataset(self) -> Dataset:
        """Load the original dataset"""
        raise NotImplementedError()

    @property
    @abstractmethod
    def _DATASETS_NAME(self) -> str:
        raise NotImplementedError

    def get_image_urls(self, batch) -> List[str]:
        return batch["image_url"]

    def get_texts(self, batch) -> List[List[str]]:
        """A list, each element is a list of potential captions"""
        return batch["caption"]

    def pre_download_image_map(self, batch: Dict) -> Dict:
        # move `captions` to `text`
        if "text" in batch:
            return batch
        batch["text"] = self.get_texts(batch)
        return batch

    def _normalise(self, batch: Dict) -> Dict:
        """Create the `text` field, typically from the field `caption`."""
        # `datasets.map` requires functions to be pure, which is not the case here
        # https://github.com/huggingface/datasets/pull/4197#issue-1211342558
        batch = batch.copy()

        # Changes to do before downloading images: typically filtering.
        new_batch = self.pre_download_image_map(batch)

        return new_batch

    def _add_image_or_exception(self, batch: Dict, image_or_exception_iterator: Iterator) -> Dict:
        """Get the images from the iterator and put them in the batch dict.
        Remove all the unnecessary columns and put them into a json dict (`meta` column).
        Add the source info to the batch dict."""
        # `datasets.map` requires functions to be pure, which is not the case here
        # https://github.com/huggingface/datasets/pull/4197#issue-1211342558
        batch = batch.copy()

        # get image or exception
        batch_size = len(next(iter(batch.values())))
        images, exceptions = tuple(zip(*[next(image_or_exception_iterator) for _ in range(batch_size)]))

        # add them to batch
        batch["image_download_exception"] = list(exceptions)
        batch["image"] = [image if image is not None else None for image in images]

        # Collapse columns to meta
        batch = collapse_columns_to_meta(
            batch, columns_to_collapse=self.get_batch_metadata_columns(batch), meta_column_name="meta"
        )

        batch["source"] = [self._DATASETS_NAME for _ in range(batch_size)]

        return batch

    def map_shard(self, shard: Dataset) -> Dataset:
        """
        Prepare the `text` fields, and download (or fetch from cache) images.
        """
        # Decide which urls we need to query
        shard = shard.map(
            self._normalise,
            batched=True,
            remove_columns=shard.column_names,
            num_proc=1,  # This is handled manually
        )

        # actually download the image.
        with PickableMediaDownloadGenerator(
            download_media_url=partial(fetch_single_image, http_client=self.http_client),
            get_media_urls=self.get_image_urls,
            dset=shard,
            batch_size=1000,
            num_threads_per_proc=self.num_threads_per_proc,
        ) as image_iterator:
            # fill the new dataset with that image_path or exception
            shard = shard.map(
                partial(self._add_image_or_exception, image_or_exception_iterator=image_iterator),
                batched=True,
                remove_columns=shard.column_names,
                features=_FEATURES,
                num_proc=1,  # SUPER IMPORTANT as `PickableMediaDownloadGenerator` is stateful
            )

        return shard

    @staticmethod
    def get_batch_metadata_columns(batch: Dict) -> List[str]:
        return list(set(batch.keys()) - {"image", "text"})

    def build(self):
        dset = self._load_dataset()
        if isinstance(dset, Dataset):
            shards = [
                dset.shard(num_shards=self.num_proc, index=rank, contiguous=True) for rank in range(self.num_proc)
            ]
            with get_context("spawn").Pool(self.num_proc) as pool:
                results = [pool.apply_async(self.map_shard, kwds={"shard": shard}) for shard in shards]
                transformed_shards = [result.get() for result in results]
                pool.terminate()
                pool.join()
            del pool
            dset = concatenate_datasets(transformed_shards)
            # Make sure the features match the expected API.
            assert dset.features == _FEATURES, f"Got: {dset.features}, expected: {_FEATURES}"
        elif isinstance(dset, DatasetDict):
            new_dset = DatasetDict()
            for split, dset_split in dset.items():
                shards = [
                    dset_split.shard(num_shards=self.num_proc, index=rank, contiguous=True)
                    for rank in range(self.num_proc)
                ]
                with get_context("spawn").Pool(self.num_proc) as pool:
                    results = [pool.apply_async(self.map_shard, kwds={"shard": shard}) for shard in shards]
                    transformed_shards = [result.get() for result in results]
                    pool.terminate()
                    pool.join()
                del pool
                new_dset[split] = concatenate_datasets(transformed_shards)
                # Make sure the features match the expected API.
                assert new_dset[split].features == _FEATURES, f"Got: {new_dset[split].features}, expected: {_FEATURES}"
            dset = new_dset
        return dset


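# Sketch of the per-shard flow above (comments only; variable names are
# illustrative): `build()` cuts the dataset into `num_proc` contiguous shards
# and hands each one to a spawned worker running `map_shard`. Inside a worker,
# `_normalise` prepares `text`, then `PickableMediaDownloadGenerator` streams
# (image, exception) pairs that `_add_image_or_exception` writes back into the
# batch; everything except `image` / `text` (including `image_download_exception`)
# is folded into `meta`.
#
#   shards = [dset.shard(num_shards=num_proc, index=rank, contiguous=True) for rank in range(num_proc)]
#   transformed_shards = [builder.map_shard(shard) for shard in shards]  # sequential equivalent of the Pool
#   dset = concatenate_datasets(transformed_shards)

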
# ---- Dataset specific loader builder -----
class COCOLoaderBuilder(BaseLoaderBuilder, ABC):
    _DATASETS_NAME = "coco"

    def _load_dataset(self):
        return load_dataset(
            f"{Path(__file__).parent / 'local_loaders' / 'coco'}",
            split=self.split,
            use_auth_token=True,
        )

    def update_before_collapsing_meta(self, batch: Dict) -> Dict:
        # move `caption` to `text`
        batch["text"] = [sents["raw"] for sents in batch["sentences"]]
        return batch


class SBUCaptionsLoaderBuilder(DownloadImageLoaderBuilder):
    _DATASETS_NAME = "sbu_captions"

    def _load_dataset(self) -> Dataset:
        return load_dataset(self._DATASETS_NAME, split=self.split)


class LocalizedNarrativesOpenImagesLoaderBuilder(DownloadImageLoaderBuilder):
    _DATASETS_NAME = "localized_narratives__openimages"

    def _load_dataset(self) -> Dataset:
        return load_dataset(
            f"{Path(__file__).parent / 'local_loaders' / 'localized_narratives__openimages'}",
            split=self.split,
            use_auth_token=True,
        )


class LocalizedNarrativesCOCOLoaderBuilder(BaseLoaderBuilder, ABC):
    _DATASETS_NAME = "localized_narratives__coco"

    def _load_dataset(self) -> Dataset:
        return load_dataset(
            f"{Path(__file__).parent / 'local_loaders' / 'localized_narratives__coco'}",
            split=self.split,
            use_auth_token=True,
        )

    def update_before_collapsing_meta(self, batch: Dict) -> Dict:
        # move `caption` to `text`
        batch["text"] = batch["caption"]
        del batch["caption"]
        return batch


class LocalizedNarrativesFlickr30kLoaderBuilder(BaseLoaderBuilder, ABC):
    _DATASETS_NAME = "localized_narratives__flickr30k"

    def _load_dataset(self) -> Dataset:
        return load_dataset(
            f"{Path(__file__).parent / 'local_loaders' / 'localized_narratives__flickr30k'}",
            data_dir=get_m4_cache_dir() / "flickr30k",
            split=self.split,
            use_auth_token=True,
        )

    def update_before_collapsing_meta(self, batch: Dict) -> Dict:
        # move `caption` to `text`
        batch["text"] = batch["caption"]
        del batch["caption"]
        return batch


class LocalizedNarrativesADE20kLoaderBuilder(BaseLoaderBuilder, ABC):
    _DATASETS_NAME = "localized_narratives__ADE20k"

    def _load_dataset(self) -> Dataset:
        return load_dataset(
            f"{Path(__file__).parent / 'local_loaders' / 'localized_narratives__ADE20k'}",
            split=self.split,
            use_auth_token=True,
        )

    def update_before_collapsing_meta(self, batch: Dict) -> Dict:
        # move `caption` to `text`
        batch["text"] = batch["caption"]
        del batch["caption"]
        return batch


class VisualGenomeLoaderBuilder(BaseLoaderBuilder, ABC):
    _DATASETS_NAME = "visual_genome"

    def _load_dataset(self) -> Dataset:
        # NOTE (Victor): loading the Karpathy splits this way is hacky, but it keeps
        # the image PMD pipeline moving quickly.
        karpathy_coco_file = get_m4_cache_dir() / "coco-captions" / "dataset_coco.json"
        with open(karpathy_coco_file, "r", encoding="utf-8") as f:
            annotations = json.load(f)

        # Mark COCO images belonging to the Karpathy test/val splits as invalid.
        invalid_images = {}
        for annotation in annotations["images"]:
            if annotation["split"] == "test" or annotation["split"] == "val":
                invalid_images[int(annotation["cocoid"])] = 1
        self.invalid_images = invalid_images

        return load_dataset(self._DATASETS_NAME, "region_descriptions_v1.2.0", split=self.split)

    def update_before_collapsing_meta(self, batch: Dict) -> Dict:
        metadata_columns = self.get_batch_metadata_columns(batch)
        new_batch = {
            **{column: [] for column in metadata_columns},
            "image": [],
            "text": [],
        }
        for image, regions, coco_id, *values in zip(
            batch["image"], batch["regions"], batch["coco_id"], *[batch[column] for column in metadata_columns]
        ):
            if coco_id is None or int(coco_id) not in self.invalid_images:
                # Crop one image per region and pair it with the region's phrase.
                slices = [
                    (region["x"], region["x"] + region["width"], region["y"], region["y"] + region["height"])
                    for region in regions
                ]
                new_batch["image"] += [
                    image.crop((x_start, y_start, x_end, y_end)) for (x_start, x_end, y_start, y_end) in slices
                ]
                new_batch["text"] += [region["phrase"] for region in regions]
                for column, value in zip(metadata_columns, values):
                    new_batch[column] += [value for _ in regions]
        return new_batch


class Conceptual12MLoaderBuilder(DownloadImageLoaderBuilder):
    _DATASETS_NAME = "conceptual_12m"

    def _load_dataset(self) -> Dataset:
        return load_dataset(self._DATASETS_NAME, split=self.split)


class RedCapsLoaderBuilder(DownloadImageLoaderBuilder):
    _DATASETS_NAME = "red_caps"

    def _load_dataset(self) -> Dataset:
        return load_dataset(self._DATASETS_NAME, "all", split=self.split)

    def get_texts(self, batch: Dict) -> List[List[str]]:
        return batch["raw_caption"]


class YFCC100MLoaderBuilder(DownloadImageLoaderBuilder):
    _DATASETS_NAME = "yfcc100m"

    def _load_dataset(self) -> Dataset:
        return load_dataset(
            f"{Path(__file__).parent / 'local_loaders' / 'yfcc100m'}",
            split=self.split,
            use_auth_token=True,
        )


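# Usage sketch (illustrative only; the split and worker counts are assumptions,
# and the m4 cache dir / Hugging Face auth token are expected to be configured):
#
#   builder = COCOLoaderBuilder(split="train", num_proc=8)
#   coco = builder.build()
#
#   builder = RedCapsLoaderBuilder(split="train", num_proc=8, num_threads_per_proc=32)
#   red_caps = builder.build()
#
# In both cases `build()` returns a `Dataset` (or a `DatasetDict`, one entry per
# split) whose features match `_FEATURES`: `image`, `text`, `source`, and a
# `meta` column holding the remaining original columns as a json dict
# (see `collapse_columns_to_meta`).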