vision/m4/training/dataset.py

""" This file defines the dataloader logic. """ import copy import inspect import logging import os from dataclasses import asdict from functools import partial from pathlib import Path from typing import Optional import datasets import numpy as np import torch import webdataset as wds from accelerate.state import AcceleratorState from PIL import Image, ImageFile from torch.utils.data import Sampler from m4.training.config import DataParams, DatasetParams, Parameters from m4.training.dataset_utils import check_webdataset_command, get_webdataset from m4.training.packing import ( split_pack_and_pad_iqa_finetuning, split_pack_and_pad_ocr, split_pack_and_pad_pairs, split_pack_and_pad_sft, split_pack_and_pad_webdocs, ) from m4.training.types import DatasetNames, DatasetTypes Image.MAX_IMAGE_PIXELS = None ImageFile.LOAD_TRUNCATED_IMAGES = True logger = logging.getLogger(__name__) """ Possible dataloader/dataset nestings DataLoaderForIterableWrapperDataset CustomChainDataset IterableWrapperDataset DataLoaderForMappedWrapperDataset MapWrapperDataset """ # TODO(siddk) :: This file needs to be cleaned up a bit? """ dataset[idx]: { 'images':[ <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD503595350>, None, None, None, None, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD503595D90>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD503595E90>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD503595F90>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD5035930D0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD503595D50>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256 at 0x7FD503593290>, ], 'texts':[ None, 'VALMAS COMMERCE', 'VALMAS SA imports top quality products and materials that meet the requirements of modern construction. The goal of our company is to provide material and technical know-how for reliable solutions in the field of construction, renovation and repair. Raw materials of high-quality for every particular need, which also remain environmentally sensitive. Investing continually in enhancing our infrastructure and with a steadily growing track record over the years, we are in a position to meet your every need. We will be glad to welcome you to our 500m² exhibition to choose from a range of quality materials the ones that best match your own requirements.', 'FLOORING – PAINTS\n\nQuality and great variety in paneling materials for classic and traditional applications of modern residential aesthetics.\n\nFLOORING: With knowledge and experience in construction, we are able to propose solutions that combine the available budget with the specific demands for functionality and aesthetics. Wooden floors, ceramic tiles, marble, stone, as well as construction flooring, such as colored or special cement mortars, offer classic or traditional choices, following the original and modern residential aesthetic result.\n\n·' 'Partners List', None, None, None, None, None, None, ] } """ def to_tensor(batch): for key in batch: if not torch.is_tensor(batch[key]): batch[key] = torch.tensor(batch[key]) return batch # We don't want collate_fn to add extra dim --> can't use a lambda because of `multiprocessing` pickle requirements! 
# We don't want collate_fn to add extra dim --> can't use a lambda because of `multiprocessing` pickle requirements!
def simple_collate(x):
    return x[0]


def get_mapper(
    tokenizer,
    image_transform,
    dataset_type: DatasetTypes,
    image_seq_len: int,
    max_seq_len: int = 256,
    max_num_images: int = 5,
    max_image_size: int = 384,
    vision_encoder_max_image_size: int = 384,
    pre_split_scale_up_max=1.0,
    pre_split_scale_up_frequency=0.0,
    is_t5: bool = False,
    pad_dataset: bool = True,
    max_num_samples_per_document: int = 1,
    t5_mlm_noise_density: Optional[float] = None,
    t5_mlm_mean_noise_span_length: Optional[int] = None,
    add_begin_of_doc_token: bool = True,
    add_end_of_doc_token: bool = True,
    max_num_images_per_document: Optional[int] = None,
):
    mapper_kwargs = {
        "tokenizer": tokenizer,
        "image_transform": image_transform,
        "max_seq_len": max_seq_len,
        "max_num_images": max_num_images,
        "max_image_size": max_image_size,
        "vision_encoder_max_image_size": vision_encoder_max_image_size,
        "pre_split_scale_up_max": pre_split_scale_up_max,
        "pre_split_scale_up_frequency": pre_split_scale_up_frequency,
        "image_seq_len": image_seq_len,
        "add_begin_of_doc_token": add_begin_of_doc_token,
        "add_end_of_doc_token": add_end_of_doc_token,
    }
    if not pad_dataset:
        raise ValueError("This feature has been deprecated. The dataset must be padded")

    if is_t5:
        mapper_kwargs["noise_density"] = t5_mlm_noise_density
        mapper_kwargs["mean_noise_span_length"] = t5_mlm_mean_noise_span_length
        raise ValueError("This feature has been deprecated. We can't pack for t5")
    elif dataset_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
        split_fn = split_pack_and_pad_pairs
    elif dataset_type == DatasetTypes.OCR:
        split_fn = split_pack_and_pad_ocr
    elif (dataset_type == DatasetTypes.VQAV2_TASK_FINETUNING) or (dataset_type == DatasetTypes.DOCVQA):
        split_fn = split_pack_and_pad_iqa_finetuning
    elif dataset_type == DatasetTypes.SFT:
        split_fn = split_pack_and_pad_sft
    elif dataset_type == DatasetTypes.WEB_DOCUMENTS:
        split_fn = split_pack_and_pad_webdocs
        mapper_kwargs["max_num_samples_per_document"] = max_num_samples_per_document
        mapper_kwargs["max_num_images_per_document"] = max_num_images_per_document

    mapper_with_args = partial(split_fn, **mapper_kwargs)
    return mapper_with_args


def get_dataloaders(
    config: Parameters,
    rank: int,
    world_size: int,
    tokenizer,
    train_image_transforms,
    val_image_transforms,
    image_seq_len: int,
):
    logger.info("Getting the train dataloader")
    train_loader = get_dataloader_from_config(
        tokenizer=tokenizer,
        image_transforms=train_image_transforms,
        seed=config.data_param.train_seed,
        config=config,
        is_train=True,
        rank=rank,
        world_size=world_size,
        image_seq_len=image_seq_len,
    )
    if config.hparams.do_validation:
        logger.info("Getting the validation dataloader")
        val_loader = get_dataloader_from_config(
            tokenizer=tokenizer,
            image_transforms=val_image_transforms,
            seed=config.data_param.val_seed,
            config=config,
            is_train=False,
            rank=rank,
            world_size=world_size,
            image_seq_len=image_seq_len,
        )
    else:
        val_loader = None
    return train_loader, val_loader
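

# Added note: `load_hf_dataset` below accepts an optional colon-separated suffix on top of the dataset
# path, i.e. "<path>", "<path>:<split>" or "<path>:<config>:<split>". Hypothetical examples (the
# dataset/config names are placeholders; only the format comes from the parsing code):
#
#     load_hf_dataset("/fsx/data/my_local_dataset")            # local `save_to_disk` or hub-like layout
#     load_hf_dataset("some_org/some_dataset:train")           # hub dataset, keep only the `train` split
#     load_hf_dataset("some_org/some_dataset:config_a:train")  # hub dataset, config `config_a`, split `train`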
def load_hf_dataset(dataset_path):
    split_name = None
    config_name = None
    if ":" in dataset_path:
        dataset_path_splitted = dataset_path.split(":")
        if len(dataset_path_splitted) == 2:
            dataset_path, split_name = dataset_path_splitted
        elif len(dataset_path_splitted) == 3:
            dataset_path, config_name, split_name = dataset_path_splitted

    if os.path.exists(dataset_path):
        # a local path dataset can be of two kinds
        # 1. generated by `save_to_disk` and thus must be loaded with `load_from_disk`
        # 2. hub-like dataset, but which is not online
        # so we try the first and if it fails with `FileNotFoundError` (despite the path existing) we try the second
        try:
            hf_dataset = datasets.load_from_disk(dataset_path)
        except FileNotFoundError:
            if config_name is not None:
                hf_dataset = datasets.load_dataset(dataset_path, name=config_name)
            else:
                hf_dataset = datasets.load_dataset(dataset_path)
    else:
        if config_name is not None:
            hf_dataset = datasets.load_dataset(
                dataset_path, name=config_name, use_auth_token=os.environ.get("HF_TOKEN", True)
            )
        else:
            hf_dataset = datasets.load_dataset(dataset_path, use_auth_token=os.environ.get("HF_TOKEN", True))

    if split_name is not None:
        hf_dataset = hf_dataset[split_name]
    return hf_dataset


def get_dataset_hf(
    dataset_config: DatasetParams,
    tokenizer,
    image_transform,
    is_train: bool = True,
    realtime_processing: bool = True,
    is_t5: bool = False,
):
    dataset_list = []
    hf_datasets_paths = (
        dataset_config.training_datasets_paths if is_train else dataset_config.validation_datasets_paths
    )

    # hf_datasets_paths can be a list of paths, or a .txt file path that contains the paths
    if len(hf_datasets_paths) == 1 and str(hf_datasets_paths[0]).endswith(".txt"):
        with open(hf_datasets_paths[0], "r") as file_shards:
            hf_datasets_paths = [path for path in file_shards.read().split("\n") if path]

    for dataset_path in hf_datasets_paths:
        hf_dataset = load_hf_dataset(dataset_path=str(dataset_path))
        is_paired_dataset = "meta" in hf_dataset[0] and "source" in hf_dataset[0]

        optional_kwargs_defaults = [
            ("pad_dataset", True),
            ("max_num_samples_per_document", 1),
            ("t5_mlm_noise_density", 0.15),
            ("t5_mlm_mean_noise_span_length", 3),
            ("add_begin_of_doc_token", True),
            ("add_end_of_doc_token", True),
            ("max_num_images_per_document", None),
        ]
        optional_kwargs = {}
        for key, default in optional_kwargs_defaults:
            optional_kwargs[key] = getattr(dataset_config, key, default)

        if not realtime_processing:
            mapper_with_args = get_mapper(
                tokenizer=tokenizer,
                image_transform=image_transform,
                image_seq_len=dataset_config.image_seq_len,
                max_seq_len=dataset_config.max_seq_len,
                max_num_images=dataset_config.max_num_images,
                max_image_size=dataset_config.max_image_size,
                vision_encoder_max_image_size=dataset_config.vision_encoder_max_image_size,
                pre_split_scale_up_max=dataset_config.pre_split_scale_up_max,
                pre_split_scale_up_frequency=dataset_config.pre_split_scale_up_frequency,
                dataset_type=DatasetTypes.IMAGE_CAPTION_PAIRS if is_paired_dataset else DatasetTypes.WEB_DOCUMENTS,
                is_t5=is_t5,
                **optional_kwargs,
            )
            hf_dataset = hf_dataset.map(
                mapper_with_args,
                batched=True,
                batch_size=dataset_config.map_batch_size,
                remove_columns=hf_dataset.column_names,
                num_proc=dataset_config.map_num_proc,
            )
        dataset_list.append(hf_dataset)
    return dataset_list
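

# Added note: when a single `.txt` file is given, it is expected to list one dataset path (or, for
# webdatasets, one shard url/command) per line, blank lines being ignored. A hypothetical shards file
# (the paths are placeholders; only the one-entry-per-line format comes from the parsing code):
#
#     /scratch/shards/dataset_a
#     /scratch/shards/dataset_b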
def get_dataset_webdataset(
    dataset_config: DatasetParams,
    is_train: bool = True,
    realtime_processing: bool = True,
):
    if not realtime_processing:
        raise NotImplementedError("WebDataset is only supported for realtime processing")

    webdataset_paths = dataset_config.training_datasets_paths if is_train else dataset_config.validation_datasets_paths
    if len(webdataset_paths) == 0:
        return None

    # webdataset_paths can be a list of paths/commands, or a .txt file path that contains the paths
    if len(webdataset_paths) == 1 and str(webdataset_paths[0]).endswith(".txt"):
        with open(webdataset_paths[0], "r") as file_shards:
            webdataset_paths = [path for path in file_shards.read().split("\n") if path]
    else:
        raise ValueError("WebDataset only supports a .txt file with the paths or the commands.")

    # Check if the paths/commands are valid
    checks = all([check_webdataset_command(path) for path in webdataset_paths])
    if not checks:
        raise ValueError("WebDataset paths/commands are not valid. Please check the paths/commands.")

    combined_dataset = get_webdataset(
        urls=webdataset_paths,
        ds_type=dataset_config.dataset_type,
        batch_size=dataset_config.map_batch_size,
        shuffle_initial_urls_list=dataset_config.shuffle_initial_urls_list if is_train else False,
        shuffle_before_split_by_node_buffer_size=(
            dataset_config.shuffle_before_split_by_node_buffer_size if is_train else None
        ),
        shuffle_before_split_by_worker_buffer_size=(
            dataset_config.shuffle_before_split_by_worker_buffer_size if is_train else None
        ),
        shuffle_after_tarfile_to_samples_buffer_size=(
            dataset_config.shuffle_after_tarfile_to_samples_buffer_size if is_train else None
        ),
        shuffle_after_batching_buffer_size=dataset_config.shuffle_after_batching_buffer_size if is_train else None,
    )
    return combined_dataset


def get_dataset(
    dataset_config: DatasetParams,
    tokenizer=None,
    image_transform=None,
    is_train: bool = True,
    realtime_processing: bool = True,
    is_t5: bool = False,
    use_webdataset: bool = False,
):
    if use_webdataset:
        return get_dataset_webdataset(
            dataset_config=dataset_config,
            is_train=is_train,
            realtime_processing=realtime_processing,
        )
    else:
        return get_dataset_hf(
            dataset_config=dataset_config,
            tokenizer=tokenizer,
            image_transform=image_transform,
            is_train=is_train,
            realtime_processing=realtime_processing,
            is_t5=is_t5,
        )


def get_dataloader(
    tokenizer,
    image_transforms,
    seed,
    num_workers=1,
    pin_memory=False,
    batch_size=10,
    is_train=True,
    persistent_workers=True,
    realtime_processing=False,
    # The following arguments are only used for the iterable dataset
    rank=None,
    world_size=None,
    # This argument is for controlling sample order randomness for Map-style Datasets when resuming a run
    sampler_rng=None,
    accumulate_datasets=False,
    model_name=None,
    data_param: Optional[DataParams] = None,
    image_seq_len=None,
):
    if data_param is None:
        raise ValueError("data_param must be provided")

    if is_train:
        select_n_examples = data_param.select_n_examples_train
    else:
        select_n_examples = data_param.select_n_examples_validation

    is_t5 = "t5" in model_name if model_name is not None else False

    dataset_list_map = {}
    # Try all possible datasets, if they don't have datasets paths in the config, they
    # will end up with an empty list
    for dataset_name in DatasetNames:
        curr_dataset_config = getattr(data_param, dataset_name.name.lower())
        dataset_list_map[dataset_name.name.lower()] = get_dataset(
            dataset_config=curr_dataset_config,
            tokenizer=tokenizer,
            image_transform=image_transforms[dataset_name.name.lower()],
            is_train=is_train,
            realtime_processing=realtime_processing,
            is_t5=is_t5,
            use_webdataset=data_param.use_webdataset,
        )

    if not realtime_processing:
        # => Important & gnarly: Set image transform based on a novel instance of `np.default_rng` seeded by the parent
        #    seed and the current rank; because of the way DataLoader `num_workers` multiprocessing works in tandem with
        #    the default `hf.dataset` Arrow backend + normal PyTorch `DataLoader, Sampler, and BatchSampler` behavior,
        #    any "global" reference to an rng object that gets to this transform will get "pickled and copied over" to
        #    each separate worker process.
        #
        #    This wouldn't be a problem if we could simply just "reset" the randomness of the Dataset, but that's opaque
        #    given the `hf.Dataset` wrapper; as such, we need to just handle the randomness ourselves by advancing the
        #    random state the appropriate amount in the `__getitem__` of the MapWrapperDataset (as that's the only other
        #    place where we're sure we're in scope for a given worker's process block).
        full_dataset = datasets.concatenate_datasets(
            [dataset for dataset_list in dataset_list_map.values() for dataset in dataset_list]
        )
        if select_n_examples is not None:
            full_dataset = full_dataset.select(range(select_n_examples))

        transform_rng = np.random.default_rng(seed=[seed, rank])

        # Wrap `full_dataset` in custom MapWrapperDataset, and initialize a ResumableSampler
        full_dataset = MapWrapperDataset(full_dataset, transform_rng)
        resume_sampler = ResumableSampler(full_dataset, sampler_rng)

        return DataLoaderForMappedWrapperDataset(
            full_dataset,
            batch_size=batch_size,
            sampler=resume_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
            persistent_workers=persistent_workers,
        )
    else:
        realtime_processing_datasets = []
        for dataset_name in DatasetNames:
            dataset_list_or_combined = dataset_list_map.get(dataset_name.name.lower(), [])
            if dataset_list_or_combined is None or (
                isinstance(dataset_list_or_combined, list) and len(dataset_list_or_combined) == 0
            ):
                continue

            if isinstance(dataset_list_or_combined, list):
                # If we have a list of datasets, we know those are hf datasets, so we can concatenate them
                combined_dataset = datasets.concatenate_datasets(dataset_list_or_combined)
                if len(combined_dataset) // max(num_workers, 1) < batch_size:
                    raise ValueError(
                        f"For real-time processing, len(dataset) [={len(combined_dataset)}] // num_workers"
                        f" [={num_workers}] must be >= batch_size [={batch_size}]!"
                    )
                if select_n_examples is not None:
                    combined_dataset = combined_dataset.select(range(select_n_examples))
                wrapper_dataset_class = IterableWrapperHFDataset
            elif isinstance(dataset_list_or_combined, wds.pipeline.DataPipeline):
                combined_dataset = dataset_list_or_combined
                wrapper_dataset_class = IterableWrapperWebdataset
            else:
                raise ValueError("Type unrecognized")

            dataset_config: DatasetParams = getattr(data_param, dataset_name.name.lower())
            dataset_kwargs = asdict(dataset_config)
            signature = inspect.signature(wrapper_dataset_class.__init__)
            dataset_kwargs = {k: v for k, v in dataset_kwargs.items() if k in signature.parameters}

            iterable_dataset_instance = wrapper_dataset_class(
                combined_dataset,
                tokenizer=tokenizer,
                image_transform=image_transforms[dataset_name.name.lower()],
                batch_size=batch_size,
                seed=seed,
                shuffle=is_train,
                rank=rank,
                world_size=world_size,
                drop_last=True,
                is_t5=is_t5,
                image_seq_len=image_seq_len,
                **dataset_kwargs,
            )
            realtime_processing_datasets.append(iterable_dataset_instance)

        full_dataset = CustomChainDataset(
            realtime_processing_datasets,
            num_workers,
            rank,
            accumulate_datasets=accumulate_datasets,
            proba_interleaving_dataset=data_param.proba_interleaving_dataset,
            is_train=is_train,
        )
        return DataLoaderForIterableWrapperDataset(
            full_dataset,
            seed=seed,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            collate_fn=simple_collate,
            drop_last=True,
            rank=rank,
            world_size=world_size,
        )


def get_dataloader_from_config(
    tokenizer,
    image_transforms,
    image_seq_len,
    seed,
    config,
    is_train,
    rank=None,
    world_size=None,
    sampler_rng=None,
):
    dataloader = get_dataloader(
        tokenizer=tokenizer,
        image_transforms=image_transforms,
        seed=seed,
        num_workers=config.data_param.num_workers,
        pin_memory=config.data_param.pin_memory,
        batch_size=config.data_param.batch_size,
        is_train=is_train,
        persistent_workers=config.data_param.persistent_workers,
        realtime_processing=config.data_param.realtime_processing,
        rank=rank,
        world_size=world_size,
        sampler_rng=sampler_rng,
        accumulate_datasets=True if config.hparams.loss_weights_per_dataset is not None else False,
        model_name=config.hparams.model_name,
        data_param=config.data_param,
        image_seq_len=image_seq_len,
    )
    return dataloader
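

# Added sketch of the resume mechanics used by the map-style path (the class names are from this file,
# the numbers are illustrative only): `ResumableSampler` draws a full permutation of indices for the
# epoch and then drops the already-consumed prefix, e.g. with 10 examples and 4 already seen:
#
#     indices = torch.randperm(10, generator=g).tolist()  # same order as the interrupted epoch
#     indices = indices[4:]                               # skip what was already consumed
#
# so a resumed run replays the exact same sample order without re-fetching the seen examples.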
# Creates a Resumable Sampler for Map-style Datasets that predefines the set of indices to retrieve for
# the given epoch (seeded with a generator!). We are using this instead of just seeding the "default Sampler"
# so we can quickly "short-circuit" and bypass any and all "seen" indices (see `__iter__`).
class ResumableSampler(Sampler):
    def __init__(self, data_source, sample_generator):
        super().__init__(data_source)
        self.data_source, self.indices = data_source, None

        # Note: `accelerate` hardcodes a search for an attribute named `generator`, which it then uses to synchronize
        # generators across all ranks on each call to `__iter__`; in our case, this is really bad as we want full
        # control over "sample" randomness (only for Map-style datasets). To get around this, it suffices to just
        # name the instance attributes anything other than `generator`... so that's what we do!
        self.sample_generator = sample_generator

        # For "instant" resuming
        self.n_seen_examples = 0

    def set_state(self, n_seen_examples_per_worker):
        self.n_seen_examples = sum(n_seen_examples_per_worker.values())

    def get_index_order(self):
        return torch.randperm(len(self.data_source), generator=self.sample_generator).tolist()

    def __iter__(self):
        self.indices = self.get_index_order()

        # Resume logic -> advance by `n_seen_examples`
        self.indices = self.indices[self.n_seen_examples :]
        yield from self.indices

    def __len__(self):
        return len(self.data_source)


# Super simple wrapper around dataset; mostly for compatibility with IterableWrapperDataset
class MapWrapperDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform_rng):
        self.wrapped_dataset, self.transform_rng = dataset, transform_rng

        # For "instant" resuming
        self.n_seen_examples_per_worker = {}

    def state_dict(self):
        state_dict = {}
        # recurse into its dataset
        state_dict["wrapped_dataset"] = self.wrapped_dataset.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        # recurse into its dataset
        self.wrapped_dataset.load_state_dict(state_dict["wrapped_dataset"])

    @property
    def dataset(self):
        return self.wrapped_dataset

    def set_state(self, n_seen_examples_per_worker):
        self.n_seen_examples_per_worker = n_seen_examples_per_worker

    def __getitem__(self, idx):
        # Dummy dataset idx used for compatibility with CustomChainDataset
        dummy_dataset_idx = 0
        worker_info = torch.utils.data.get_worker_info()
        n_seen_examples = self.n_seen_examples_per_worker.get(worker_info.id, 0)

        # If `n_seen_examples` is non-zero --> advance random state "quickly" then return new example!
        if n_seen_examples > 0:
            for _ in range(n_seen_examples):
                self.transform_rng.random()
            self.n_seen_examples_per_worker[worker_info.id] = 0

        return dummy_dataset_idx, worker_info.id, self.wrapped_dataset[idx]

    def __len__(self):
        return len(self.wrapped_dataset)


class IterableWrapperHFDataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        dataset,
        tokenizer,
        image_transform,
        batch_size,
        seed,
        dataset_type,
        dataset_name,
        image_seq_len,
        shuffle=True,
        rank=None,
        world_size=None,
        drop_last=True,
        is_t5=False,
        # Dataset specific params
        max_num_images=5,
        max_seq_len=256,
        max_image_size=384,
        vision_encoder_max_image_size=384,
        pre_split_scale_up_max=1.0,
        pre_split_scale_up_frequency=0.0,
        pad_dataset=True,
        mapper_batch_size=128,
        # Setting default to 0 as PMD doesn't need it and it is set to 0.5 in CM4 config by default
        max_num_samples_per_document=1,
        t5_mlm_noise_density=None,
        t5_mlm_mean_noise_span_length=None,
        add_begin_of_doc_token=True,
        add_end_of_doc_token=True,
        shuffle_after_packing=False,
        max_num_images_per_document=None,
    ):
        self.dataset = dataset
        self.mapper = get_mapper(
            tokenizer=tokenizer,
            image_transform=image_transform,
            image_seq_len=image_seq_len,
            max_seq_len=max_seq_len,
            max_num_images=max_num_images,
            max_image_size=max_image_size,
            vision_encoder_max_image_size=vision_encoder_max_image_size,
            pre_split_scale_up_max=pre_split_scale_up_max,
            pre_split_scale_up_frequency=pre_split_scale_up_frequency,
            pad_dataset=pad_dataset,
            max_num_samples_per_document=max_num_samples_per_document,
            dataset_type=dataset_type,
            is_t5=is_t5,
            t5_mlm_noise_density=t5_mlm_noise_density,
            t5_mlm_mean_noise_span_length=t5_mlm_mean_noise_span_length,
            add_begin_of_doc_token=add_begin_of_doc_token,
            add_end_of_doc_token=add_end_of_doc_token,
            max_num_images_per_document=max_num_images_per_document,
        )
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.rank = rank
        self.world_size = world_size
        self.mapper_batch_size = mapper_batch_size
        self.drop_last = drop_last
        self.shuffle_after_packing = shuffle_after_packing

        # To be initialized later in __iter__
        self.rng = None

        # Resume Tracking --> Dict[worker_idx] -> Tuple[map_idx, key_idx]; `map_idx` lets us jumpstart!
        self.worker_idx_tracker = {}
        self.start_worker_idx = 0

        self.dataset_name = dataset_name

    def set_state(self, worker_idx_tracker, start_worker_idx):
        self.worker_idx_tracker = worker_idx_tracker
        self.start_worker_idx = start_worker_idx

    def set_epoch(self, epoch):
        self.epoch = epoch

    def state_dict(self):
        state_dict = {}
        return state_dict

    def load_state_dict(self, state_dict):
        pass

    def _get_worker_id_and_worker_total_num(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_total_num = 1
            worker_id = 0
        else:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        worker_id = (worker_id + self.start_worker_idx) % worker_total_num
        return worker_id, worker_total_num

    def _get_worker_indices(self):
        sampler = torch.utils.data.DistributedSampler(
            self.dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=self.shuffle,
            seed=self.seed,
            drop_last=self.drop_last,
        )
        sampler.set_epoch(self.epoch)

        # Get list of indices for current rank from the distributed sampler
        indices = list(iter(sampler))

        worker_id, worker_total_num = self._get_worker_id_and_worker_total_num()

        # Take the subset of indices that belong to this worker
        # => It will look something like 0, 2, 4 when worker_id is 0 and num_workers is 2
        worker_indices = indices[worker_id::worker_total_num]
        return worker_indices, worker_id

    def __iter__(self):
        # Dummy dataset idx used for compatibility with CustomChainDataset
        dummy_dataset_idx = 0

        if self.rank is None or self.world_size is None:
            raise ValueError("rank and world_size must be provided")

        # Get worker indices & details for resuming...
        worker_indices, worker_id = self._get_worker_indices()
        num_worker_indices = len(worker_indices)

        # Set start idx of loop based on `self.worker_resume_idxs`
        map_start_idx, last_key_idx, overflow_batch = self.worker_idx_tracker.get(worker_id, (0, -1, {}))

        self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, map_start_idx]
        self.rng = np.random.default_rng(seed=self.rng_seed)

        for i in range(map_start_idx, num_worker_indices, self.mapper_batch_size):
            # Set the seed for the worker according to the worker index and the loop index, then reset it afterwards
            # This needs to be done so that torch random crop is deterministic
            rng_state = torch.get_rng_state()
            torch.manual_seed(f"{self.seed}{worker_id}{i}")

            # Feed `worker_indices[i]` to mapper to ensure "deterministic randomness" that we don't have to track...
            curr_mapped_batch = self.mapper(
                self.dataset[worker_indices[i : i + self.mapper_batch_size]],
                prefix_seed=(self.seed, self.epoch, self.rank, worker_id, i),
            )
            torch.set_rng_state(rng_state)

            keys = list(curr_mapped_batch.keys())
            overflow_batch_keys = overflow_batch.keys()

            # Check if overflow from previous batches is left, if yes, add it to the current batch
            # Specifically, we should prepend this overflow batch so that it goes out first and
            # the current batch possibly becomes the next overflow batch
            if len(overflow_batch_keys) > 0:
                if sorted(overflow_batch_keys) != sorted(keys):
                    raise ValueError(
                        "Overflow batch keys not equal to current keys. Make sure mapper is always returning"
                        " dictionary with the same keys. "
" f"Overflow: {sorted(overflow_batch_keys)}, Mapping: {sorted(keys)}" ) else: mapped_batch = {} if "pixel_values" in overflow_batch or "pixel_values" in curr_mapped_batch: total_batch_size = overflow_batch["input_ids"].size(0) + curr_mapped_batch["input_ids"].size(0) max_num_images = max( overflow_batch["pixel_values"].size(1) if "pixel_values" in overflow_batch else 0, curr_mapped_batch["pixel_values"].size(1) if "pixel_values" in curr_mapped_batch else 0, ) max_height = max( overflow_batch["pixel_values"].size(3) if "pixel_values" in overflow_batch else 0, curr_mapped_batch["pixel_values"].size(3) if "pixel_values" in curr_mapped_batch else 0, ) max_width = max( overflow_batch["pixel_values"].size(4) if "pixel_values" in overflow_batch else 0, curr_mapped_batch["pixel_values"].size(4) if "pixel_values" in curr_mapped_batch else 0, ) padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width) padded_pixel_attention_masks = torch.zeros( total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool ) start = 0 for batch in [overflow_batch, curr_mapped_batch]: if "pixel_values" not in batch: continue px = batch["pixel_values"] px_attn_mask = batch["pixel_attention_mask"] end = start + px.size(0) padded_image_tensor[start:end, :, :, : px.size(3), : px.size(4)] = px padded_pixel_attention_masks[start:end, :, : px.size(3), : px.size(4)] = px_attn_mask start += px.size(0) mapped_batch["pixel_values"] = padded_image_tensor.contiguous() mapped_batch["pixel_attention_mask"] = padded_pixel_attention_masks.contiguous() for key in keys: if key in ["pixel_values", "pixel_attention_mask"]: continue mapped_batch[key] = torch.cat([overflow_batch[key], curr_mapped_batch[key]], dim=0) previous_overflow_batch = copy.deepcopy(overflow_batch) overflow_batch = {} else: previous_overflow_batch = {} mapped_batch = curr_mapped_batch first_key = keys[0] mapped_batch_length = len(mapped_batch[first_key]) if self.shuffle_after_packing: indices = list(range(mapped_batch_length)) self.rng.shuffle(indices) for key in mapped_batch.keys(): mapped_batch[key] = mapped_batch[key][indices, ...] 
            if mapped_batch_length < self.batch_size:
                # We need to add more data to this batch to make it of size `self.batch_size`
                # Just setting mapped_batch to overflow_batch should be enough as the next iteration
                # will add more data to it
                overflow_batch = mapped_batch
            else:
                # Now, yield batches of size batch_size from the mapped batch
                for key_idx in range(0, mapped_batch_length, self.batch_size):
                    # Set "reproducible" randomness
                    self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i, key_idx]
                    self.rng = np.random.default_rng(seed=self.rng_seed)

                    if i == map_start_idx and key_idx <= last_key_idx:
                        # Handle Resume (only for "first" loop iteration) advance random state until `last_key_idx`
                        self.rng.random()
                    else:
                        overflow_batch = {key: mapped_batch[key][key_idx : key_idx + self.batch_size] for key in keys}
                        if len(overflow_batch[first_key]) != self.batch_size:
                            # Last batch
                            break
                        else:
                            dataset_state = {
                                "worker_idx": worker_id,
                                "map_start_idx": i,
                                "last_key_idx": key_idx,
                                "previous_overflow_batch": previous_overflow_batch,
                            }
                            yield dummy_dataset_idx, self.dataset_name.name.lower(), dataset_state, overflow_batch
                            overflow_batch = {}


class IterableWrapperWebdataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        dataset,
        tokenizer,
        image_transform,
        batch_size,
        seed,
        dataset_type,
        dataset_name,
        image_seq_len,
        shuffle=True,
        rank=None,
        world_size=None,
        drop_last=True,
        is_t5=False,
        # Dataset specific params
        max_num_images=5,
        max_seq_len=256,
        max_image_size=384,
        vision_encoder_max_image_size=384,
        pre_split_scale_up_max=1.0,
        pre_split_scale_up_frequency=0.0,
        pad_dataset=True,
        mapper_batch_size=128,
        # Setting default to 0 as PMD doesn't need it and it is set to 0.5 in CM4 config by default
        max_num_samples_per_document=1,
        t5_mlm_noise_density=None,
        t5_mlm_mean_noise_span_length=None,
        add_begin_of_doc_token=True,
        add_end_of_doc_token=True,
        shuffle_after_packing=False,
        max_num_images_per_document=None,
    ):
        self._webdataset = dataset
        self.dataset = iter(self._webdataset)
        self.mapper = get_mapper(
            tokenizer=tokenizer,
            image_transform=image_transform,
            image_seq_len=image_seq_len,
            max_seq_len=max_seq_len,
            max_num_images=max_num_images,
            max_image_size=max_image_size,
            vision_encoder_max_image_size=vision_encoder_max_image_size,
            pre_split_scale_up_max=pre_split_scale_up_max,
            pre_split_scale_up_frequency=pre_split_scale_up_frequency,
            pad_dataset=pad_dataset,
            max_num_samples_per_document=max_num_samples_per_document,
            dataset_type=dataset_type,
            is_t5=is_t5,
            t5_mlm_noise_density=t5_mlm_noise_density,
            t5_mlm_mean_noise_span_length=t5_mlm_mean_noise_span_length,
            add_begin_of_doc_token=add_begin_of_doc_token,
            add_end_of_doc_token=add_end_of_doc_token,
            max_num_images_per_document=max_num_images_per_document,
        )
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.rank = rank
        self.world_size = world_size
        self.mapper_batch_size = mapper_batch_size
        self.drop_last = drop_last
        self.shuffle_after_packing = shuffle_after_packing

        # To be initialized later in __iter__
        self.rng = None

        # Resume Tracking --> Dict[worker_idx] -> Tuple[map_idx, key_idx]; `map_idx` lets us jumpstart!
        self.worker_idx_tracker = {}
        self.start_worker_idx = 0

        self.dataset_name = dataset_name

    def set_state(self, worker_idx_tracker, start_worker_idx):
        self.worker_idx_tracker = worker_idx_tracker
        self.start_worker_idx = start_worker_idx

    def reset_state(self):
        for key in self.worker_idx_tracker.keys():
            self.worker_idx_tracker[key] = (0, -1, {})
        self.start_worker_idx = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.dataset = iter(self._webdataset)  # Reset dataset iterator

    def state_dict(self):
        state_dict = {}
        return state_dict

    def load_state_dict(self, state_dict):
        pass

    def _get_worker_id_and_worker_total_num(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_total_num = 1
            worker_id = 0
        else:
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        worker_id = (worker_id + self.start_worker_idx) % worker_total_num
        return worker_id, worker_total_num

    def _get_worker_indices(self):
        sampler = torch.utils.data.DistributedSampler(
            self.dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=self.shuffle,
            seed=self.seed,
            drop_last=self.drop_last,
        )
        sampler.set_epoch(self.epoch)

        # Get list of indices for current rank from the distributed sampler
        indices = list(iter(sampler))

        worker_id, worker_total_num = self._get_worker_id_and_worker_total_num()

        # Take the subset of indices that belong to this worker
        # => It will look something like 0, 2, 4 when worker_id is 0 and num_workers is 2
        worker_indices = indices[worker_id::worker_total_num]
        return worker_indices, worker_id

    def __iter__(self):
        # Dummy dataset idx used for compatibility with CustomChainDataset
        dummy_dataset_idx = 0

        if self.rank is None or self.world_size is None:
            raise ValueError("rank and world_size must be provided")

        # Relic from previous implementation - but needed for rng seed
        worker_id, worker_total_num = self._get_worker_id_and_worker_total_num()
        # Relic from previous implementation - but needed for rng seed
        map_start_idx, last_key_idx, overflow_batch = self.worker_idx_tracker.get(worker_id, (0, -1, {}))
        # Relic from previous implementation - but needed for rng seed
        i = map_start_idx

        # Initialize rng_seed
        self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i]
        self.rng = np.random.default_rng(seed=self.rng_seed)

        while True:
            # Set the seed for the worker according to the worker index and the loop index, then reset it afterwards
            # This needs to be done so that torch random crop is deterministic
            rng_state = torch.get_rng_state()
            torch.manual_seed(f"{self.seed}{worker_id}{i}")

            try:
                next_batch = next(self.dataset)
                i += 1
            except StopIteration:
                logger.info(
                    f"{self.dataset_name.name.lower()} has finished one epoch and is moving on to the next one."
                    f" (epoch={self.epoch} - rank={self.rank} - worker_id={worker_id})"
                )
                break

            curr_mapped_batch = self.mapper(
                next_batch,
                prefix_seed=(self.seed, self.epoch, self.rank, worker_id, i),
            )
            torch.set_rng_state(rng_state)

            keys = list(curr_mapped_batch.keys())
            overflow_batch_keys = overflow_batch.keys()

            # Check if overflow from previous batches is left, if yes, add it to the current batch
            # Specifically, we should prepend this overflow batch so that it goes out first and
            # the current batch possibly becomes the next overflow batch
            if len(overflow_batch_keys) > 0:
                if sorted(overflow_batch_keys) != sorted(keys):
                    raise ValueError(
                        "Overflow batch keys not equal to current keys. Make sure mapper is always returning"
                        " dictionary with the same keys. "
" f"Overflow: {sorted(overflow_batch_keys)}, Mapping: {sorted(keys)}" ) else: mapped_batch = {} if "pixel_values" in overflow_batch or "pixel_values" in curr_mapped_batch: total_batch_size = overflow_batch["input_ids"].size(0) + curr_mapped_batch["input_ids"].size(0) max_num_images = max( overflow_batch["pixel_values"].size(1) if "pixel_values" in overflow_batch else 0, curr_mapped_batch["pixel_values"].size(1) if "pixel_values" in curr_mapped_batch else 0, ) max_height = max( overflow_batch["pixel_values"].size(3) if "pixel_values" in overflow_batch else 0, curr_mapped_batch["pixel_values"].size(3) if "pixel_values" in curr_mapped_batch else 0, ) max_width = max( overflow_batch["pixel_values"].size(4) if "pixel_values" in overflow_batch else 0, curr_mapped_batch["pixel_values"].size(4) if "pixel_values" in curr_mapped_batch else 0, ) padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width) padded_pixel_attention_masks = torch.zeros( total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool ) start = 0 for batch in [overflow_batch, curr_mapped_batch]: if "pixel_values" not in batch: continue px = batch["pixel_values"] px_attn_mask = batch["pixel_attention_mask"] end = start + px.size(0) padded_image_tensor[start:end, :, :, : px.size(3), : px.size(4)] = px padded_pixel_attention_masks[start:end, :, : px.size(3), : px.size(4)] = px_attn_mask start += px.size(0) mapped_batch["pixel_values"] = padded_image_tensor.contiguous() mapped_batch["pixel_attention_mask"] = padded_pixel_attention_masks.contiguous() for key in keys: if key in ["pixel_values", "pixel_attention_mask"]: continue mapped_batch[key] = torch.cat([overflow_batch[key], curr_mapped_batch[key]], dim=0) overflow_batch = {} else: mapped_batch = curr_mapped_batch first_key = keys[0] mapped_batch_length = len(mapped_batch[first_key]) if self.shuffle_after_packing: indices = list(range(mapped_batch_length)) self.rng.shuffle(indices) for key in mapped_batch.keys(): mapped_batch[key] = mapped_batch[key][indices, ...] if mapped_batch_length < self.batch_size: # We need to add more data to this batch to make it of size `self.batch_size` # Just setting mapped_batch to overflow_batch should be enough as the next iteration # will add more data to it overflow_batch = mapped_batch else: # Now, yield batches of size batch_size from the mapped batch for key_idx in range(0, mapped_batch_length, self.batch_size): # Set "reproducible" randomness self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i, key_idx] self.rng = np.random.default_rng(seed=self.rng_seed) overflow_batch = {key: mapped_batch[key][key_idx : key_idx + self.batch_size] for key in keys} if len(overflow_batch[first_key]) != self.batch_size: # Last batch break else: dataset_state = { "worker_idx": worker_id, "map_start_idx": i, "last_key_idx": key_idx, "previous_overflow_batch": {}, } yield dummy_dataset_idx, self.dataset_name.name.lower(), dataset_state, overflow_batch overflow_batch = {} class CustomChainDataset(torch.utils.data.IterableDataset): r"""Dataset for chaining multiple :class:`IterableDataset` s. This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient. For validation set: This class iterates over each dataset one by one. One dataset is iterated over completely before moving to the next one. 
class CustomChainDataset(torch.utils.data.IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    For the validation set: this class iterates over each dataset one by one. One dataset is iterated over
    completely before moving to the next one.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
        num_workers (int): number of workers to use for loading data
        rank (int): rank of the current process
        accumulate_datasets (bool): whether to accumulate the datasets or not
        proba_interleaving_dataset (list of float): probability of interleaving each dataset; the first
            probability is that of the PMD, the second one that of CM4
        is_train (bool): whether the dataset is used for training or not. Setting this to False will ignore
            the accumulate_datasets and proba_interleaving_dataset parameters and set them to False and None.
            See the commentary above about the validation set to understand the behavior of this class when
            `is_train` is False.
    """

    def __init__(
        self, datasets, num_workers, rank, accumulate_datasets=False, proba_interleaving_dataset=None, is_train=True
    ):
        super(CustomChainDataset, self).__init__()
        for d in datasets:
            if not isinstance(d, torch.utils.data.IterableDataset):
                raise ValueError(f"CustomChainDataset only supports IterableDataset, but got {type(d)}")
        self.datasets = datasets
        self.num_workers = num_workers if num_workers > 1 else 1
        self.num_datasets = len(self.datasets)
        self.is_train = is_train
        if self.is_train is False:
            if accumulate_datasets is True or proba_interleaving_dataset is not None:
                logger.warning("accumulate_datasets and proba_interleaving_dataset are ignored when is_train is False")
            self.accumulate_datasets = False
            self.dataset_proba = None
        else:
            self.accumulate_datasets = accumulate_datasets
            if not self.accumulate_datasets:
                if proba_interleaving_dataset is not None:
                    self.dataset_proba = np.asarray(proba_interleaving_dataset)
                else:
                    self.dataset_proba = np.full((self.num_datasets), 1 / self.num_datasets, dtype=float)
                if abs(self.dataset_proba.sum() - 1) > 0.001:
                    # Allow a tolerance for floating point rounding errors.
raise ValueError("proba_interleaving_dataset must sum to 1") self.epoch = 0 self.seed = sum(dataset.seed for dataset in self.datasets) self.rank = rank # state-related attributes self.start_worker_id = 0 self.chain_dataset_last_idx_tracker = {} self.reset_state() def reset_state(self): self.chain_dataset_last_idx_tracker = {worker_id: 0 for worker_id in range(self.num_workers)} def update_state(self, dataset_state): self.chain_dataset_last_idx_tracker[dataset_state["worker_idx"]] = dataset_state["chain_dataset_last_idx"] def load_resume_states(self, resumable_states): for idx, d in enumerate(self.datasets): worker_idx_tracker, start_worker_id = resumable_states[idx] d.set_state(worker_idx_tracker, start_worker_id) self.start_worker_id = start_worker_id def state_dict(self): state_dict = {} state_dict["chain_dataset_last_idx_tracker"] = self.chain_dataset_last_idx_tracker # recurse into its datasets state_dict["datasets"] = [d.state_dict() for d in self.datasets] return state_dict def load_state_dict(self, state_dict): for key in state_dict["chain_dataset_last_idx_tracker"].keys(): if key in self.chain_dataset_last_idx_tracker: self.chain_dataset_last_idx_tracker[key] = state_dict["chain_dataset_last_idx_tracker"][key] else: self.chain_dataset_last_idx_tracker[key] = 0 # recurse into its datasets for idx, d in enumerate(self.datasets): d.load_state_dict(state_dict["datasets"][idx]) def set_epoch(self, epoch): # TODO: change this and fix trainer as well when epoch logic # described in iter is implemented self.epoch = epoch for d in self.datasets: d.set_epoch(epoch) def _get_worker_id(self): worker_info = torch.utils.data.get_worker_info() if worker_info is None: worker_total_num = 1 worker_id = 0 else: worker_total_num = worker_info.num_workers worker_id = worker_info.id worker_id = (worker_id + self.start_worker_id) % worker_total_num return worker_id def __iter__(self): ds_iterators = [iter(d) for d in self.datasets] worker_id = self._get_worker_id() # Needed for dataset accumulation. Allows chain dataset to start at the right # dataset_idx, as it needs to know what was the last idx was. # Note that for validation set, this will always start from zero. 
        dataset_idx = self.chain_dataset_last_idx_tracker[worker_id]
        while True:
            rng_seed = [
                self.seed,
                self.epoch,
                self.rank,
                worker_id,
                self.chain_dataset_last_idx_tracker[worker_id],
            ]
            rng = np.random.default_rng(seed=rng_seed)
            if self.is_train:
                if self.accumulate_datasets:
                    dataset_idx = (dataset_idx + 1) % self.num_datasets
                else:
                    dataset_idx = rng.choice(np.arange(0, self.num_datasets), p=self.dataset_proba)
            try:
                _, dataset_name, dataset_state, batch = next(ds_iterators[dataset_idx])
                self.chain_dataset_last_idx_tracker[worker_id] += 1
                dataset_state["chain_dataset_last_idx"] = self.chain_dataset_last_idx_tracker[worker_id]
                yield dataset_idx, dataset_name, dataset_state, batch
            except StopIteration:
                if not self.is_train and dataset_idx < self.num_datasets - 1:
                    dataset_idx += 1
                else:
                    self.epoch += 1
                    self.reset_state()
                    for d in self.datasets:
                        d.reset_state()
                        d.set_epoch(self.epoch)
                    ds_iterators = [iter(d) for d in self.datasets]
                    # TODO: Move epoch logic here instead of the training loop,
                    # i.e. make an infinite dataloader that keeps track of the epochs of
                    # each dataset


# this class is currently not being maintained
class DataLoaderForMappedWrapperDataset(torch.utils.data.DataLoader):
    def state_dict(self):
        state_dict = {}
        # recurse into its dataset
        state_dict["dataset"] = self.dataset.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        # recurse into its dataset
        self.dataset.load_state_dict(state_dict["dataset"])


# These are notes from IterableResumableState that got folded into
# DataLoaderForIterableWrapperDataset, but the comments are still relevant:
#
# **IMPORTANT**: Serializing "transform" randomness is really really difficult to get right because any "global"
# generators we initialize are not actually going to be respected in any DataLoader with `num_workers > 1`. When you
# run a DataLoader w/ multiprocessing, each "rng" generator is "copied" over to the separate process; setting that
# generator only works _within_ the parts of the Dataset/Sampler/BatchSampler/DataLoader that are actually in scope
# while an individual worker process is live... and there's no way I can currently think of to cleanly do that...
class DataLoaderForIterableWrapperDataset(torch.utils.data.DataLoader):
    def __init__(self, dataset, seed, **kwargs):
        self.seed = seed
        self.rank = kwargs.pop("rank", None)
        self.world_size = kwargs.pop("world_size", None)
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        if self.rank is None or self.world_size is None:
            try:
                state = AcceleratorState()
                self.rank = state.process_index
                self.device = state.device
                self.world_size = state.num_processes
            except ValueError:
                # This will fail later on when we try to create the IterableWrapperDataset
                pass

        if self.world_size == 1:
            # we normally don't use one process, but when we do in testing there can be a problem
            # with the "Too many open files" issue, and switching to the "file_system" strategy overcomes it.
            torch.multiprocessing.set_sharing_strategy("file_system")

        dataset.rank = self.rank
        dataset.world_size = self.world_size

        super().__init__(dataset, **kwargs)
        self.dataset_count = len(self.dataset.datasets)
        self.worker_idx_tracker = [{} for i in range(self.dataset_count)]
        self.next_worker_idx = [0 for i in range(self.dataset_count)]
        self.reset_state()

    def reset_state(self):
        for dataset_idx in range(self.dataset_count):
            self.worker_idx_tracker[dataset_idx] = {}
            self.next_worker_idx[dataset_idx] = 0

        # TODO: Aman: It looks like this has somewhat changed from last I checked.
        # On slack, we discussed not keeping the states in the datasets themselves but only on the dataloader,
        # and having the current state returned from __iter__ of the dataset.
        #
        # The reason for this is what we discussed on slack about the dataset object being unique
        # to each worker and being passed around. So, the current dataset object in your dataloader
        # possibly won't give you the correct state but rather the state it had when it was initialized. Are
        # we confident that it will work this way?
        self.dataset.reset_state()

    def update_state(self, dataset_idx, dataset_state):
        self.dataset.update_state(dataset_state)
        self.worker_idx_tracker[dataset_idx][dataset_state["worker_idx"]] = (
            dataset_state["map_start_idx"],
            dataset_state["last_key_idx"],
            dataset_state["previous_overflow_batch"],
        )
        for dataset_idx in range(self.dataset_count):
            self.next_worker_idx[dataset_idx] = (dataset_state["worker_idx"] + 1) % self.num_workers

    def set_epoch(self, epoch):
        # Ensure each epoch has a different sequence to sample from
        self.dataset.set_epoch(epoch)

    def load_resume_states(self):
        self.dataset.load_resume_states(self.get_resume_states())

    def get_resume_state(self, dataset_idx):
        return self.worker_idx_tracker[dataset_idx], self.next_worker_idx[dataset_idx]

    def get_resume_states(self):
        return [
            [self.worker_idx_tracker[dataset_idx], self.next_worker_idx[dataset_idx]]
            for dataset_idx in range(self.dataset_count)
        ]

    def state_dict(self):
        state_dict = {}
        state_dict["worker_idx_tracker"] = self.worker_idx_tracker
        state_dict["next_worker_idx"] = self.next_worker_idx
        # recurse into its dataset
        state_dict["dataset"] = self.dataset.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        self.worker_idx_tracker = state_dict["worker_idx_tracker"]
        self.next_worker_idx = state_dict["next_worker_idx"]
        # recurse into its dataset
        self.dataset.load_state_dict(state_dict["dataset"])

    def save_state(self, path):
        """Saves state_dict to `m4_states_{process_index}.pkl`"""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        location = path / f"m4_states_{self.rank}.pkl"
        logger.info(f"Saving the DL states to {location}")
        torch.save(self.state_dict(), location)

    def load_state(self, path):
        """Loads the state_dict at `m4_states_{process_index}.pkl`"""
        location = Path(path) / f"m4_states_{self.rank}.pkl"
        logger.info(f"Loading the DL states from {location}")
        self.load_state_dict(torch.load(location))

    def __iter__(self):
        main_iterator = super().__iter__()
        stopped = torch.tensor(0.0, device=self.device)
        while True:
            try:
                batch = next(main_iterator)
            except StopIteration:
                stopped += 1

            # check that dist is initialized in case this DL w/ world_size>1 was used w/o a distributed environment
            if self.world_size > 1 and torch.distributed.is_initialized():
                torch.distributed.all_reduce(stopped, op=torch.distributed.ReduceOp.SUM)

            # stop iterating if one or more processes stopped to avoid blocking
            if stopped > 0:
                break

            yield batch


# This class isn't being maintained (or tested) at the moment
class MappedResumableState:
    def __init__(self, train_seed, val_seed, world_size, rank, dataloader_num_workers, using_iterable_dataset=False):
        self.train_seed, self.val_seed, self.world_size, self.rank = train_seed, val_seed, world_size, rank
        self.dataloader_num_workers, self.using_iterable_dataset = dataloader_num_workers, using_iterable_dataset
        self.epoch = None
        self.n_seen_examples_per_worker = {w: 0 for w in range(self.dataloader_num_workers)}

        # Create generators for the various sources of randomness
        self.sampler_rng, self.sampler_rng_val = torch.Generator(), torch.Generator()
        self.sampler_rng.manual_seed(self.train_seed)
        self.sampler_rng_val.manual_seed(self.val_seed)

        # Create instance variables to track the `sampler_rng` states (see comment in `set_epoch` for why!)
        self.sampler_rng_state = self.sampler_rng.get_state()
        self.sampler_rng_val_state = self.sampler_rng_val.get_state()

    def set_epoch(self, epoch):
        # At the start of each epoch, the iter(sampler) gets called, which advances the random state; this isn't
        # great, as it means that if we save the random state post-hoc, when we resume, we're "skipping" to the
        # next epoch. To preserve the _same_ order within an epoch, we actually need to save the random state
        # prior to the call to __iter__!
        if self.epoch != epoch:
            self.epoch = epoch
            self.sampler_rng_state = self.sampler_rng.get_state()
            self.sampler_rng_val_state = self.sampler_rng_val.get_state()

    def update_state(self, state_trackers):
        # Gate based on `self.using_iterable_dataset`
        # For a Map-style dataset, `state_trackers` is a Tensor containing all worker_ids responsible for fetching
        # the given examples of the batch; we'll be using this to increment `n_seen_examples_per_worker`. We'll
        # be using the key assumption that a single worker generates all elements for the batch!
        self.n_seen_examples_per_worker[state_trackers[0].item()] += state_trackers.numel()

    def update_next_worker_idx(self, curr_worker_idx):
        self.next_worker_idx = (curr_worker_idx + 1) % self.dataloader_num_workers

    def get_resume_state(self):
        # Return minimal `state` to "fast-forward" a given Dataset/DataLoader
        return self.n_seen_examples_per_worker

    def state_dict(self):
        return {
            "epoch": self.epoch,
            "world_size": self.world_size,
            "dataloader_num_workers": self.dataloader_num_workers,
            "sampler_rng_state": self.sampler_rng_state,
            "sampler_rng_state_checksum": self.sampler_rng_state.sum(),
            "sampler_rng_val_state": self.sampler_rng_val_state,
            "sampler_rng_val_state_checksum": self.sampler_rng_val_state.sum(),
            "n_seen_examples_per_worker": self.n_seen_examples_per_worker,
        }

    def load_state_dict(self, state_dict):
        # Verify same "world size" and "num workers" on a resume!
        if self.world_size != state_dict["world_size"]:
            raise ValueError(f"Current world_size {self.world_size} != Loaded world_size {state_dict['world_size']}")
        if self.dataloader_num_workers != state_dict["dataloader_num_workers"]:
            raise ValueError(
                f"Current num_workers `{self.dataloader_num_workers}` != Loaded num_workers"
                f" `{state_dict['dataloader_num_workers']}`"
            )

        # Set epoch
        self.epoch = state_dict["epoch"]

        # Set `sampler_rng`
        self.sampler_rng.set_state(state_dict["sampler_rng_state"])
        self.sampler_rng_val.set_state(state_dict["sampler_rng_val_state"])

        # Set `n_seen_examples_per_worker`
        self.n_seen_examples_per_worker = state_dict["n_seen_examples_per_worker"]
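

# Added usage sketch (hypothetical wiring, not executed as part of this module): the main entry point above
# is `get_dataloaders`, typically called once per training run with a `Parameters` config, e.g.
#
#     train_loader, val_loader = get_dataloaders(
#         config=config,
#         rank=rank,
#         world_size=world_size,
#         tokenizer=tokenizer,
#         train_image_transforms=train_image_transforms,
#         val_image_transforms=val_image_transforms,
#         image_seq_len=image_seq_len,
#     )
#
# with `config.data_param.realtime_processing` selecting between the map-style and iterable code paths.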