# timm/data/readers/reader_hfids.py
""" Dataset reader for HF IterableDataset
"""
import math
import os
from itertools import repeat, chain
from typing import Optional

import torch
import torch.distributed as dist
from PIL import Image

try:
import datasets
from datasets.distributed import split_dataset_by_node
from datasets.splits import SplitInfo
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
raise e

from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount

SHUFFLE_SIZE = int(os.environ.get('HFIDS_SHUFFLE_SIZE', 4096))
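# The default shuffle buffer size above can be overridden via the environment, e.g.
# `HFIDS_SHUFFLE_SIZE=16384 python train.py ...` (16384 is only an illustrative value;
# a larger buffer shuffles more thoroughly at the cost of memory and startup time).
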
class ReaderHfids(Reader):
def __init__(
self,
name: str,
root: Optional[str] = None,
split: str = 'train',
is_training: bool = False,
batch_size: int = 1,
download: bool = False,
repeats: int = 0,
seed: int = 42,
class_map: Optional[dict] = None,
input_key: str = 'image',
input_img_mode: str = 'RGB',
target_key: str = 'label',
target_img_mode: str = '',
shuffle_size: Optional[int] = None,
num_samples: Optional[int] = None,
trust_remote_code: bool = False
):
super().__init__()
self.root = root
self.split = split
self.is_training = is_training
self.batch_size = batch_size
self.download = download
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.shuffle_size = shuffle_size or SHUFFLE_SIZE
self.input_key = input_key
self.input_img_mode = input_img_mode
self.target_key = target_key
self.target_img_mode = target_img_mode
self.builder = datasets.load_dataset_builder(
name,
cache_dir=root,
trust_remote_code=trust_remote_code,
)
if download:
self.builder.download_and_prepare()
        split_info: Optional[SplitInfo] = None
        if self.builder.info.splits and split in self.builder.info.splits:
            if isinstance(self.builder.info.splits[split], SplitInfo):
                split_info = self.builder.info.splits[split]
if num_samples:
self.num_samples = num_samples
elif split_info and split_info.num_examples:
self.num_samples = split_info.num_examples
else:
raise ValueError(
"Dataset length is unknown, please pass `num_samples` explicitly. "
"The number of steps needs to be known in advance for the learning rate scheduler."
)
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = {}
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init
self.worker_info = None
self.worker_id = 0
self.num_workers = 1
self.global_worker_id = 0
self.global_num_workers = 1
# Initialized lazily on each dataloader worker process
self.ds: Optional[datasets.IterableDataset] = None
self.epoch = SharedCount()

    def set_epoch(self, count):
        # update the epoch so that shuffling uses effective_seed = seed + epoch
        self.epoch.value = count

    def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
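        """Update loader configuration (currently just the worker count).

        Only takes effect before the dataset is built, i.e. before the first iteration.
        """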
if self.ds is not None:
return
if num_workers is not None:
self.num_workers = num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers

    def _lazy_init(self):
""" Lazily initialize worker (in worker processes)
"""
if self.worker_info is None:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.worker_info = worker_info
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
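            # e.g. with 2 replicas and 4 workers each (hypothetical sizes), rank 1 / worker 2
            # gets global_worker_id = 1 * 4 + 2 = 6 out of global_num_workers = 8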
if self.download:
dataset = self.builder.as_dataset(split=self.split)
            # shard the dataset so it can be split evenly across all workers
ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
else:
            # in this case the number of shards is determined by the number of remote files
ds = self.builder.as_streaming_dataset(split=self.split)
if self.is_training:
# will shuffle the list of shards and use a shuffle buffer
ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)
        # Distributed:
        # If the number of shards is a multiple of `dist_num_replicas` (i.e. `ds.n_shards % dist_num_replicas == 0`),
        # the shards are assigned evenly across the nodes.
        # If that is not the case (common when streaming), each node keeps 1 example out of
        # every `dist_num_replicas`, skipping the other examples.
        # Workers:
        # Within a node, datasets.IterableDataset assigns the node's shards as evenly as possible across workers.
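        # A worked example with hypothetical counts: 8 shards across 2 replicas gives each
        # node 4 shards; with 7 shards across 2 replicas the shards cannot be split evenly,
        # so each node iterates the full stream and keeps every 2nd example instead.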
self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)

    def _num_samples_per_worker(self):
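        """Number of samples this worker should yield to keep per-worker streams balanced.

        A worked example with hypothetical numbers: num_samples=1_281_167,
        global_num_workers=8, batch_size=256, is_training=True yields
        ceil(1_281_167 / 8) = 160_146 -> ceil(160_146 / 256) * 256 = 160_256.
        """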
num_worker_samples = \
max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)

    def __iter__(self):
if self.ds is None:
self._lazy_init()
self.ds.set_epoch(self.epoch.value)
target_sample_count = self._num_samples_per_worker()
sample_count = 0
if self.is_training:
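            # repeat the dataset indefinitely; the loop below stops this worker once
            # target_sample_count samples have been yielded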
ds_iter = chain.from_iterable(repeat(self.ds))
else:
ds_iter = iter(self.ds)
for sample in ds_iter:
input_data: Image.Image = sample[self.input_key]
if self.input_img_mode and input_data.mode != self.input_img_mode:
input_data = input_data.convert(self.input_img_mode)
target_data = sample[self.target_key]
if self.target_img_mode:
assert isinstance(target_data, Image.Image), "target_img_mode is specified but target is not an image"
if target_data.mode != self.target_img_mode:
target_data = target_data.convert(self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
break

    def __len__(self):
        return self._num_samples_per_worker() * self.num_workers

    def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples

    def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if self.ds is None:
self._lazy_init()
names = []
for sample in self.ds:
if 'file_name' in sample:
name = sample['file_name']
elif 'filename' in sample:
name = sample['filename']
elif 'id' in sample:
name = sample['id']
elif 'image_id' in sample:
name = sample['image_id']
else:
assert False, "No supported name field present"
names.append(name)
return names
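

if __name__ == '__main__':
    # Minimal smoke test, not part of timm's API: stream a few samples from a
    # hypothetical HF Hub dataset (requires network access; the name, column key,
    # and sample count below are assumptions, adjust for the dataset you use).
    reader = ReaderHfids(
        'uoft-cs/cifar10',
        split='test',
        input_key='img',     # assumption: this dataset stores images under 'img', not 'image'
        num_samples=10_000,  # assumption: passed explicitly in case split metadata is missing
    )
    for i, (img, target) in enumerate(reader):
        print(i, img.size, img.mode, target)
        if i >= 3:
            break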