modules/SwissArmyTransformer/sat/data_utils/configure_data.py

# -*- encoding: utf-8 -*-
'''
@File    :   configure_data.py
@Time    :   2021/01/11 23:28:38
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# here put the import lib
import copy
import math
import os
import random
import sys
from bisect import bisect_right
from functools import partial

import numpy as np
import torch
from torch.utils import data
from torch.utils.data import ChainDataset, IterableDataset

from sat import mpu
from sat.helpers import print_all, print_rank0
from .samplers import DistributedBatchSampler


def make_data_loader(dataset, batch_size, args, split, collate_fn=None):
    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    distributed = world_size > 1

    # If IterableDataset, assume everything is properly configured (pre-sharded).
    if isinstance(dataset, IterableDataset):
        if split in ['val', 'test'] and args.strict_eval:
            raise ValueError('IterableDataset cannot be used for validation or testing if `args.strict_eval=True`, '
                             'because we cannot infer the length of the final batch before reading it out.')
        args.val_last_shape = [1] * world_size  # just a placeholder, not actually used
        args.val_drop_number = 0
        args.test_last_shape = [1] * world_size
        args.test_drop_number = 0
        per_rank_batch_size = None if args.iterable_dataset == 'custom' else batch_size // world_size
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=per_rank_batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
            timeout=10
        )

    sampler = torch.utils.data.SequentialSampler(dataset)
    drop_last = False  # COMMENT: this is already solved by the complex logic of last_shape and drop_number.
    # the GPUs in the same model parallel group receive the same data
    if distributed:  # TODO reformat this, but it is not urgent
        gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
        batch_sampler = DistributedBatchSampler(sampler, batch_size, drop_last, rank, world_size,
                                                gradient_accumulation_steps=gradient_accumulation_steps)
    else:
        batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

    last_len = len(dataset) % batch_size
    batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len // batch_per_worker)  # some processes get a full batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker)  # one process gets the rest (< 1 batch)
        drop_number = world_size - ((last_len - 1) // batch_per_worker + 1)
        # other processes get nothing, but append 1 so they keep running; dropped later according to drop_number.
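        # Illustrative example (added, not from the original source): with len(dataset)=1003,
        # batch_size=32 and world_size=4, we get last_len=11 and batch_per_worker=8, so last_shape
        # starts as [8, 3] and drop_number = 4 - ceil(11/8) = 2; the loop below pads last_shape to
        # [8, 3, 1, 1] so that every rank still yields a batch, and the padding batches are dropped
        # later during strict evaluation.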
        for j in range(drop_number):
            last_shape.append(1)
    else:
        drop_number = 0

    if split == 'val':
        args.val_last_shape = last_shape
        args.val_drop_number = drop_number
    elif split == 'test':
        args.test_last_shape = last_shape
        args.test_drop_number = drop_number

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn,
                                              prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
                                              )
    return data_loader


def make_dataset_full(path, split, args, create_dataset_function,
                      dataset_weights=None, random_mapping=True, is_train_data=False,
                      batch_from_same_dataset=False, **kwargs):
    """function to create datasets+tokenizers for common options"""
    print_all('make dataset ' + str(path), level='DEBUG')
    assert isinstance(path, list)

    if (is_train_data and args.iterable_dataset) or (not is_train_data and args.iterable_dataset_eval):
        # cannot be indexed
        # random mapping is flexible and efficient, but sometimes we hit practical issues:
        # for instance, someone just gives you an iterable dataset, e.g. a webdataset.
        from .webds import ConfiguredResampledShards, DataPipeline
        valid_types = (ConfiguredResampledShards, DataPipeline)

        assert split[0] == 1, 'Iterable dataset cannot auto split.'
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            assert isinstance(d, valid_types)
            ds.append(d)
        # ds = ChainDataset(ds)  # please merge them into one url if you want to chain them
        if batch_from_same_dataset:
            assert args.num_workers <= 1, 'We cannot control the actual speed of different workers; they may mix different iterable parts.'
        ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed,
                          batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
        return ds

    if split is None:
        split = [1.]
    if not should_split(split):
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            ds.append(d)
        ds = ConcatDataset(ds, weights=dataset_weights)
        if random_mapping:
            if args.epochs is not None:
                # not auto-scale, but use a given number of epochs.
                ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
            else:
                world_size = torch.distributed.get_world_size(
                    group=mpu.get_data_parallel_group())
                if is_train_data:
                    # only the train dataset sets this to True,
                    # so we enlarge it to make sure that the data is sufficient.
                    scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                else:
                    scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                ds = RandomMappingDataset(ds, scale=scale)
        return ds
    else:
        # must first split the datasets, then reweight/concat, finally random-map.
        # this order avoids overlapping.
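        # Added note: each source dataset is split on its own first, so a sample assigned to the
        # train subset can never reappear in valid/test; reweighting (which duplicates samples in
        # ConcatDataset) and RandomMappingDataset are applied only to the already-disjoint subsets.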
        train_ds, valid_ds, test_ds = [], [], []
        for p in path:
            d = create_dataset_function(p, args)
            if should_split(split):
                dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
                train_ds.append(dtrain)
                valid_ds.append(dvalid)
                test_ds.append(dtest)
        train_ds = ConcatDataset(train_ds, weights=dataset_weights)
        valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
        test_ds = ConcatDataset(test_ds, weights=dataset_weights)
        if random_mapping:
            world_size = torch.distributed.get_world_size(
                group=mpu.get_data_parallel_group())
            scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
            train_ds = RandomMappingDataset(train_ds, scale=scale)
            scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
            valid_ds = RandomMappingDataset(valid_ds, scale=scale)
            test_ds = RandomMappingDataset(test_ds)
        return train_ds, valid_ds, test_ds


def make_loaders(args, create_dataset_function, collate_fn=None):
    """makes training/val/test dataloaders
    Args:
        args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
        args.split: str. format: "8,1,1". how to split train_data.
        args.dataset_type: used to create the right datasets.
    """
    make_dataset = partial(make_dataset_full,
                           create_dataset_function=create_dataset_function,
                           batch_from_same_dataset=args.batch_from_same_dataset)

    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size

    split = get_split(args)

    data_set_args = {
        'path': args.train_data,
        'split': split,
    }

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]

    # make the dataset splits
    train = None
    valid = None
    test = None

    if args.train_data is not None:
        train = make_dataset(**data_set_args, args=args,
                             dataset_weights=args.train_data_weights, is_train_data=True)
        if should_split(split):
            train, valid, test = train

    # make the val and test datasets if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)

    # wrap datasets with data loaders
    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
        args.do_test = True
    else:
        args.do_test = False

    return train, valid, test


def get_split(args):
    """
    Get dataset splits from a comma-separated string list
    """
    splits = []
    if args.split.find(',') != -1:
        splits = [float(s) for s in args.split.split(',')]
    elif args.split.find('/') != -1:
        splits = [float(s) for s in args.split.split('/')]
    else:
        splits = [float(args.split)]
    split_total = sum(splits)
    if split_total < 1.:
        splits.append(1 - split_total)
    while len(splits) < 3:
        splits.append(0.)
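    # Illustrative example (added, not from the original source): args.split='95,5' gives
    # splits=[95., 5., 0.] at this point and [0.95, 0.05, 0.0] after the normalization below;
    # if args.valid_data is set, the second entry is zeroed before normalizing, so the
    # held-out portion is not split off from train_data a second time.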
    splits = splits[:3]
    if args.valid_data is not None:
        splits[1] = 0.
    if args.test_data is not None:
        splits[2] = 0.
    final_sum = sum(splits)
    return [s / final_sum for s in splits]


def should_split(split):
    """
    given split proportions checks if should split
    Examples:
    >>> should_split([10,0,0])
    False
    >>> should_split([1,.1,.2])
    True
    """
    return max(split) / sum(split) != 1.


def split_ds(ds, split=[.8, .2, .0], block_size=10000, seed=131):
    """
    Split a dataset into subsets given proportions of how
    much to allocate per split. If a split is 0%, returns None for that split.
    Purpose: Useful for creating train/val/test splits
    Arguments:
        ds (Dataset or array-like): Data to be split.
        split (1D array-like): proportions to split `ds`. `sum(split) != 0`
        block_size (int): size of the blocks used to subsample `ds`.
        seed (int): random seed for permuting the indices within a block.
    """
    split_sum = sum(split)
    if split_sum == 0:
        raise Exception('Split cannot sum to 0.')
    split = np.array(split)
    split /= split_sum

    assert block_size <= len(ds)

    start_idx = 0
    residual_idx = 0
    rtn_ds = [None] * len(split)
    rng = np.random.default_rng(seed)
    indices = rng.permutation(np.array(range(block_size)))
    for i, f in enumerate(split):
        if f != 0:
            proportion = block_size * split[i]
            residual_idx += proportion % 1
            split_ = int(int(proportion) + residual_idx)
            rtn_ds[i] = BlockedRandomSplitDataset(ds, indices[range(start_idx, start_idx + max(split_, 1))], block_size)
            start_idx += split_
            residual_idx %= 1
    return rtn_ds


class ConcatDataset(data.Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets, as the concatenation operation is done on the fly.
    Arguments:
        datasets (sequence): List of datasets to be concatenated.
    """

    @staticmethod
    def cumsum(sequence, weights):
        r, s = [], 0
        for i, e in enumerate(sequence):
            l = int(len(e) * weights[i])
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets, weights=None, **kwargs):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        if weights is None:
            self.weights = [1] * len(self.datasets)
        else:
            self.weights = weights
        self.cumulative_sizes = self.cumsum(self.datasets, self.weights)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % len(self.datasets[dataset_idx])
        return self.datasets[dataset_idx][sample_idx]


class RandomMappingDataset(data.Dataset):
    '''
    Dataset wrapper to randomly map indices to the original order.
    Will also enlarge the length.
    '''
    def __init__(self, ds, scale=200, **kwargs):
        self.wrapped_data = ds
        self.scale = scale

    def __len__(self):
        return len(self.wrapped_data) * self.scale

    def __getitem__(self, index):
        rng = random.Random(index)
        rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
        index = rng.randint(len(self.wrapped_data))
        return self.wrapped_data[index]


class RandomDataset(data.Dataset):
    '''
    Dataset wrapper to randomly map indices to the original order.
    The indices are pre-processed.
    Will also enlarge the length.
    '''
    def __init__(self, ds, scale=200, seed=131, **kwargs):
        self.wrapped_data = ds
        self.scale = scale
        self.indices = np.random.default_rng(seed).permutation(np.array(range(len(ds))))

    def __len__(self):
        return len(self.wrapped_data) * self.scale

    def __getitem__(self, index):
        return self.wrapped_data[int(self.indices[index % len(self.wrapped_data)])]


class BlockedRandomSplitDataset(data.Dataset):
    '''
    Dataset wrapper to access a subset of another dataset.
    Uses a block algorithm to reduce memory usage.
    In each block, only the items at `indices` are used.
    '''
    def __init__(self, ds, indices, block_size, **kwargs):
        if type(indices) is not np.ndarray:
            indices = np.array(indices)
        indices = np.sort(indices)
        self.block_size = block_size
        self.wrapped_data = ds
        self.wrapped_data_len = len(ds)
        self.indices = indices
        self.len = len(indices) * (len(ds) // block_size) + np.sum(indices < (len(ds) % block_size))

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.wrapped_data[(index // len(self.indices)) * self.block_size
                                 + self.indices[index % len(self.indices)]]


class AlterDataset(IterableDataset):
    def __init__(self, datasets, weights=None, seed=0, batch_from_same_dataset=False, batch_size=1):
        super().__init__()
        self.seed = seed
        self.datasets = datasets
        self.batch_from_same_dataset = batch_from_same_dataset
        self.batch_size = batch_size  # only used when batch_from_same_dataset is True
        if weights is None:
            self.weights = [1. / len(self.datasets)] * len(self.datasets)
        else:
            s = sum(weights)
            self.weights = [w / s for w in weights]

    def __iter__(self):
        iterators = [iter(d) for d in self.datasets]
        # Assume that all datasets iterate according to the seed
        # (mp-rank, seed, dataloader-worker-id (0 if not used in a dataloader)), auto-detected at iter().
        try:
            from sat.mpu import get_data_parallel_rank
            dp_rank = get_data_parallel_rank()
        except Exception:
            dp_rank = 0
        if self.batch_from_same_dataset:
            rng = np.random.default_rng(seed=[self.seed])
        else:
            rng = np.random.default_rng(seed=[dp_rank, self.seed])
        # sample from the streaming datasets according to weights
        while True:
            index = rng.choice(len(iterators), p=self.weights)
            # if an iterator is exhausted, remove it
            try:
                if self.batch_from_same_dataset:
                    # we need to make sure that consecutive batch_size samples come from the same
                    # iterable dataset, but this does not work with gradient accumulation.
                    for i in range(self.batch_size - 1):
                        yield next(iterators[index])
                yield next(iterators[index])
            except StopIteration:
                del iterators[index]
                del self.weights[index]
                if len(iterators) == 0:
                    break
                s = sum(self.weights)
                self.weights = [w / s for w in self.weights]
                print_rank0(f'AlterDataset: removed a dataset, {len(iterators)} left.')
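

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original module).
# It exercises only the pure-Python pieces (split_ds / ConcatDataset /
# BlockedRandomSplitDataset) with a hypothetical toy dataset, so it runs without
# torch.distributed or the sat mpu setup that make_loaders/make_data_loader need,
# e.g. via `python -m sat.data_utils.configure_data`.
if __name__ == '__main__':
    class _ToyDataset(data.Dataset):
        """Hypothetical dataset that simply returns its own index."""
        def __init__(self, n):
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return idx

    toy = _ToyDataset(100)
    # 80/20/0 split computed block by block; the 0% split comes back as None.
    train, val, test = split_ds(toy, split=[0.8, 0.2, 0.0], block_size=10, seed=131)
    print(len(train), len(val), test)           # expected: 80 20 None
    # Weighted concatenation: the second dataset is weighted 2x, so len == 100 + 2 * 20 = 140.
    merged = ConcatDataset([toy, val], weights=[1, 2])
    print(len(merged), merged[0], merged[120])  # 140, sample 0 of toy, then a sample from val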