# -*- encoding: utf-8 -*-
'''
@File : configure_data.py
@Time : 2021/01/11 23:28:38
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# imports
import copy
import math
import os
import random
import sys
from bisect import bisect_right
from functools import partial
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import ChainDataset, IterableDataset
from sat import mpu
from sat.helpers import print_all, print_rank0
from .samplers import DistributedBatchSampler
def make_data_loader(dataset, batch_size, args, split, collate_fn=None):
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
distributed = world_size > 1
# if IterableDataset, assume everything is properly configured. (pre-sharded)
if isinstance(dataset, IterableDataset):
if split in ['val', 'test'] and args.strict_eval:
            raise ValueError('IterableDataset cannot be used for validation or testing if `args.strict_eval=True`, because we cannot infer the length of the final batch before reading it out.')
args.val_last_shape = [1] * world_size # just fake it, not actually used
args.val_drop_number = 0
args.test_last_shape = [1] * world_size
args.test_drop_number = 0
per_rank_batch_size = None if args.iterable_dataset == 'custom' else batch_size//world_size
return torch.utils.data.DataLoader(
dataset,
batch_size=per_rank_batch_size,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
timeout=10
)
sampler = torch.utils.data.SequentialSampler(dataset)
    drop_last = False # NOTE: dropping the last partial batch is handled by the last_shape / drop_number logic below.
# the GPUs in the same model parallel group receive the same data
if distributed: # TODO reformat this, but it is not urgent
gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
batch_sampler = DistributedBatchSampler(sampler,
batch_size,
drop_last,
rank,
world_size,
gradient_accumulation_steps=gradient_accumulation_steps)
else:
batch_sampler = torch.utils.data.BatchSampler(sampler,
batch_size,
drop_last)
last_len = len(dataset) % batch_size
batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len//batch_per_worker) # some ranks get a full per-rank batch
if last_len != 0:
if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker) # one rank gets the remainder (less than a full per-rank batch)
drop_number = world_size - ((last_len-1)//batch_per_worker + 1)
        # the remaining ranks get nothing; append 1 so they can still run, and drop their results later according to drop_number.
for j in range(drop_number):
last_shape.append(1)
else:
drop_number = 0
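    # Worked example (illustrative only, nothing here is executed): with len(dataset)=10,
    # batch_size=8 and world_size=4, we get batch_per_worker=2 and last_len=2, so last_shape
    # starts as [2]; drop_number = 4 - ((2-1)//2 + 1) = 3 and last_shape becomes [2, 1, 1, 1]:
    # one rank gets a real sub-batch of 2, the other three run on a dummy sample dropped later.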
if split=='val':
args.val_last_shape = last_shape
args.val_drop_number = drop_number
elif split=='test':
args.test_last_shape = last_shape
args.test_drop_number = drop_number
data_loader = torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
)
return data_loader
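# A minimal usage sketch for make_data_loader above (hedged: `MyMapDataset` is a placeholder,
# and the argparse fields are assumed to come from sat's standard training arguments; it also
# assumes torch.distributed and sat.mpu are already initialized):
#
#   train_ds = MyMapDataset(...)  # any map-style torch Dataset
#   loader = make_data_loader(train_ds, batch_size=args.batch_size * world_size,
#                             args=args, split='train', collate_fn=None)
#
# Note that `batch_size` here is the global batch size; DistributedBatchSampler divides it
# across the data-parallel ranks.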
def make_dataset_full(path, split, args, create_dataset_function,
dataset_weights=None, random_mapping=True, is_train_data=False, batch_from_same_dataset=False, **kwargs):
"""function to create datasets+tokenizers for common options"""
print_all('make dataset ' + str(path), level='DEBUG')
assert isinstance(path, list)
    if (is_train_data and args.iterable_dataset) or (not is_train_data and args.iterable_dataset_eval): # cannot be indexed
        # Random mapping is flexible and efficient, but sometimes there are practical issues:
        # for instance, someone just hands you an iterable dataset, e.g. a webdataset.
from .webds import ConfiguredResampledShards, DataPipeline
valid_types = (ConfiguredResampledShards, DataPipeline)
        assert split[0] == 1, 'Iterable dataset cannot be auto-split.'
ds = []
for p in path:
d = create_dataset_function(p, args)
assert isinstance(d, valid_types)
ds.append(d)
        # ds = ChainDataset(ds) # if chaining is needed, merge the sources into a single url instead
if batch_from_same_dataset:
            assert args.num_workers <= 1, 'We cannot control the actual speed of different workers, which may mix samples from different iterable parts.'
ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed, batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
return ds
if split is None:
split = [1.]
if not should_split(split):
ds = []
for p in path:
d = create_dataset_function(p, args)
ds.append(d)
ds = ConcatDataset(ds, weights=dataset_weights)
if random_mapping:
            if args.epochs is not None: # no auto-scaling; use the given number of epochs.
ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
else:
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
if is_train_data:
# only train-dataset will set this to True,
# so we enlarge it to make sure that the data is sufficient.
scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
else:
scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
ds = RandomMappingDataset(ds, scale=scale)
return ds
else:
        # We must first split the datasets, then reweight/concat, and finally apply random mapping;
        # this order avoids overlap between the splits.
train_ds, valid_ds, test_ds = [], [], []
for p in path:
d = create_dataset_function(p, args)
if should_split(split):
dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
train_ds.append(dtrain)
valid_ds.append(dvalid)
test_ds.append(dtest)
train_ds = ConcatDataset(train_ds, weights=dataset_weights)
valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
test_ds = ConcatDataset(test_ds, weights=dataset_weights)
if random_mapping:
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
train_ds = RandomMappingDataset(train_ds, scale=scale)
scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
valid_ds = RandomMappingDataset(valid_ds, scale=scale)
test_ds = RandomMappingDataset(test_ds)
return train_ds, valid_ds, test_ds
def make_loaders(args, create_dataset_function, collate_fn=None):
"""makes training/val/test
Args:
args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
args.split: str. format: "8,1,1". how to split train_data.
args.dataset_type: use to create the right datasets.
"""
make_dataset = partial(make_dataset_full,
create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
batch_size = args.batch_size * world_size
eval_batch_size = batch_size
if args.eval_batch_size is not None:
eval_batch_size = args.eval_batch_size * world_size
split = get_split(args)
data_set_args = {
'path': args.train_data,
'split': split,
}
eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
    # make dataset splits and tokenizer
train = None
valid = None
test = None
if args.train_data is not None:
train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
if should_split(split):
train, valid, test = train
    # make the val and test datasets if they were not produced by splitting train_data
if valid is None and args.valid_data is not None:
eval_set_args['path'] = args.valid_data
valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
if test is None and args.test_data is not None:
eval_set_args['path'] = args.test_data
test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
# wrap datasets with data loader
if train is not None and args.batch_size > 0:
train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
args.do_train = True
else:
args.do_train = False
eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
if valid is not None:
valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
args.do_valid = True
else:
args.do_valid = False
if test is not None:
test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
args.do_test = True
else:
args.do_test = False
return train, valid, test
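# A hedged sketch of how make_loaders is typically driven (the argparse values are illustrative
# and `create_my_dataset` is a placeholder for a user-supplied factory):
#
#   def create_my_dataset(path, args):
#       ...  # return a map-style Dataset built from `path`
#
#   # e.g. args.train_data=['data/train'], args.split='95,5,0', args.valid_data=None:
#   # train_data is split 95/5 into train/val, loaders are built, and
#   # args.do_train / args.do_valid / args.do_test are set accordingly.
#   train_loader, val_loader, test_loader = make_loaders(args, create_my_dataset)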
def get_split(args):
"""
Get dataset splits from comma separated string list
"""
splits = []
if args.split.find(',') != -1:
splits = [float(s) for s in args.split.split(',')]
elif args.split.find('/') != -1:
splits = [float(s) for s in args.split.split('/')]
else:
splits = [float(args.split)]
split_total = sum(splits)
if split_total < 1.:
splits.append(1-split_total)
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
if args.valid_data is not None:
splits[1] = 0.
if args.test_data is not None:
splits[2] = 0.
final_sum = sum(splits)
return [s/final_sum for s in splits]
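# Worked example for get_split (illustrative only): args.split='8,1,1' yields [0.8, 0.1, 0.1];
# if args.valid_data and args.test_data are both given, the val/test fractions are zeroed and
# the result renormalizes to [1.0, 0.0, 0.0]. args.split='0.9' (with no valid/test data given)
# becomes [0.9, 0.1, 0.0].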
def should_split(split):
"""
given split proportions checks if should split
Examples:
>>> should_split([10,0,0])
False
>>> should_split([1,.1,.2])
True
"""
return max(split) / sum(split) != 1.
def split_ds(ds, split=[.8,.2,.0], block_size = 10000, seed=131):
"""
Split a dataset into subsets given proportions of how
much to allocate per split. If a split is 0% returns None for that split.
Purpose: Useful for creating train/val/test splits
Arguments:
ds (Dataset or array-like): Data to be split.
        split (1D array-like): proportions to split `ds` into; `sum(split)` must be non-zero.
        block_size (int): size of the blocks used for in-block shuffling; must satisfy block_size <= len(ds). Default: 10000
        seed (int): random seed for the in-block permutation. Default: 131
"""
split_sum = sum(split)
if split_sum == 0:
raise Exception('Split cannot sum to 0.')
split = np.array(split)
split /= split_sum
assert block_size <= len(ds)
start_idx = 0
residual_idx = 0
rtn_ds = [None]*len(split)
rng = np.random.default_rng(seed)
indices = rng.permutation(np.array(range(block_size)))
for i, f in enumerate(split):
if f != 0:
proportion = block_size*split[i]
residual_idx += proportion % 1
split_ = int(int(proportion) + residual_idx)
rtn_ds[i] = BlockedRandomSplitDataset(ds, indices[range(start_idx, start_idx+max(split_, 1))], block_size)
start_idx += split_
residual_idx %= 1
return rtn_ds
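# Worked example for split_ds (illustrative only): with len(ds)=25000, split=[.8,.2,.0] and
# block_size=10000, a single permutation of range(10000) is drawn; the first 8000 permuted
# in-block positions form the train subset and the next 2000 the val subset, and the same
# per-block assignment is reused across every block of `ds`. The third entry is 0, so the
# test subset stays None.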
class ConcatDataset(data.Dataset):
"""
Dataset to concatenate multiple datasets.
Purpose: useful to assemble different existing datasets, possibly
large-scale datasets as the concatenation operation is done in an
on-the-fly manner.
Arguments:
        datasets (sequence): List of datasets to be concatenated.
        weights (sequence, optional): per-dataset weights; dataset i contributes
            int(len(datasets[i]) * weights[i]) virtual samples. Defaults to all 1s.
"""
@staticmethod
def cumsum(sequence, weights):
r, s = [], 0
for i, e in enumerate(sequence):
l = int(len(e) * weights[i])
r.append(l + s)
s += l
return r
def __init__(self, datasets, weights=None, **kwargs):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
if weights is None:
self.weights = [1] * len(self.datasets)
else:
self.weights = weights
self.cumulative_sizes = self.cumsum(self.datasets, self.weights)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx = bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % len(self.datasets[dataset_idx])
return self.datasets[dataset_idx][sample_idx]
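# A hedged usage sketch for ConcatDataset (ds_a / ds_b are placeholders for map-style datasets):
#
#   mixed = ConcatDataset([ds_a, ds_b], weights=[2, 1])
#   # len(mixed) == 2 * len(ds_a) + len(ds_b); indices falling into the virtually doubled
#   # ds_a region wrap around via `sample_idx % len(ds_a)`.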
class RandomMappingDataset(data.Dataset):
'''
    Dataset wrapper that randomly maps indices into the wrapped dataset.
    Also enlarges the (virtual) length by `scale`.
'''
def __init__(self, ds, scale=200, **kwargs):
self.wrapped_data = ds
self.scale = scale
def __len__(self):
return len(self.wrapped_data) * self.scale
def __getitem__(self, index):
rng = random.Random(index)
rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
index = rng.randint(len(self.wrapped_data))
return self.wrapped_data[index]
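# Illustrative note: RandomMappingDataset(ds, scale=200) reports len == 200 * len(ds) and, for
# each index, draws a pseudo-random position in `ds` seeded by that index, so the mapping is
# deterministic across epochs and workers but is not a permutation (samples may repeat).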
class RandomDataset(data.Dataset):
'''
    Dataset wrapper that maps indices through a pre-computed random permutation
    of the wrapped dataset.
    Also enlarges the (virtual) length by `scale`.
'''
def __init__(self, ds, scale=200, seed=131, **kwargs):
self.wrapped_data = ds
self.scale = scale
self.indices = np.random.default_rng(seed).permutation(np.array(range(len(ds))))
def __len__(self):
return len(self.wrapped_data) * self.scale
def __getitem__(self, index):
return self.wrapped_data[int(self.indices[index % len(self.wrapped_data)])]
class BlockedRandomSplitDataset(data.Dataset):
'''
Dataset wrapper to access a subset of another dataset.
    Uses a block algorithm to reduce memory:
    within each block of `block_size` items, only the positions in `indices` are used.
'''
def __init__(self, ds, indices, block_size, **kwargs):
if type(indices) is not np.ndarray:
indices = np.array(indices)
indices = np.sort(indices)
self.block_size = block_size
self.wrapped_data = ds
self.wrapped_data_len = len(ds)
self.indices = indices
self.len = len(indices) * (len(ds) // block_size) + np.sum(indices < (len(ds) % block_size))
def __len__(self):
return self.len
def __getitem__(self, index):
return self.wrapped_data[(index // len(self.indices)) * self.block_size + self.indices[index % len(self.indices)]]
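# Worked example for BlockedRandomSplitDataset (illustrative only): with block_size=4,
# indices=[1, 3] and len(ds)=10, the wrapper exposes items 1, 3, 5, 7, 9 of `ds` (positions
# 1 and 3 inside each full block; the partial last block of length 2 only contributes
# position 1), so len == 2*2 + 1 == 5.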
class AlterDataset(IterableDataset):
def __init__(self, datasets, weights=None, seed=0, batch_from_same_dataset=False, batch_size=1):
super().__init__()
self.seed = seed
self.datasets = datasets
self.batch_from_same_dataset = batch_from_same_dataset
self.batch_size = batch_size # only used when batch_from_same_dataset is True
if weights is None:
self.weights = [1. / len(self.datasets)] * len(self.datasets)
else:
s = sum(weights)
self.weights = [w / s for w in weights]
def __iter__(self):
iterators = [iter(d) for d in self.datasets]
        # Assume that each dataset's iteration follows the seed (mp-rank, seed, dataloader-worker-id (0 if not used in a dataloader)), auto-detected at iter().
try:
from sat.mpu import get_data_parallel_rank
dp_rank = get_data_parallel_rank()
except Exception:
dp_rank = 0
if self.batch_from_same_dataset:
rng = np.random.default_rng(seed=[self.seed])
else:
rng = np.random.default_rng(seed=[dp_rank, self.seed])
# sampling according to weights from streaming data
while True:
index = rng.choice(len(iterators), p=self.weights)
            # on StopIteration, remove the exhausted iterator and renormalize the remaining weights
try:
if self.batch_from_same_dataset:
                    # make sure that consecutive batch_size samples come from the same iterable dataset.
                    # NOTE: gradient accumulation does not work in this mode.
for i in range(self.batch_size - 1):
yield next(iterators[index])
yield next(iterators[index])
except StopIteration:
del iterators[index]
del self.weights[index]
if len(iterators) == 0:
break
s = sum(self.weights)
self.weights = [w / s for w in self.weights]
print_rank0(f'AlterDataset: remove a dataset, {len(iterators)} left.')
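# A hedged usage sketch for AlterDataset (the pipelines below are placeholders; in this module
# they are normally built by create_dataset_function via sat's webds helpers):
#
#   ds = AlterDataset([pipeline_a, pipeline_b], weights=[3, 1], seed=args.seed)
#   # samples are drawn roughly 3:1 from pipeline_a vs pipeline_b; an exhausted pipeline is
#   # removed and the remaining weights are renormalized until every pipeline is empty.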