# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""data sampler."""
from itertools import chain
import torch
from chatlearn.utils import utils


class SingleDataSampler:
    """Splits sample indices into micro-batches and yields this data parallel rank's slice."""
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, dynamic_batch_size_flag=False, drop_last=False):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.remainder = (total_samples - consumed_samples) % data_parallel_size \
if dynamic_batch_size_flag else 0
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size + self.remainder
self.drop_last = drop_last
self.data_parallel_size = data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data_parallel_size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
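        # The first `remainder` ranks each take one extra sample. For example,
        # with micro_batch_size=2, data_parallel_size=3 and remainder=2, the
        # per-rank slices are [0:3], [3:6] and [6:8].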
start_batch_size_plus = self.data_parallel_rank if self.data_parallel_rank < self.remainder else self.remainder
start_idx = self.data_parallel_rank * self.micro_batch_size + start_batch_size_plus
batch_size_plus = 1 if self.data_parallel_rank < self.remainder else 0
batch_size = self.micro_batch_size + batch_size_plus
end_idx = start_idx + batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
        # The last partial batch is dropped if drop_last is set.
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
        # Yield the last partial batch unless drop_last is set.
if len(batch) > 0 and not self.drop_last:
indices = utils.split_index(len(batch), self.data_parallel_size)
start_idx, end_idx = indices[self.data_parallel_rank]
yield batch[start_idx:end_idx]
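
# Illustrative usage of SingleDataSampler (hypothetical numbers). With
# total_samples=10, micro_batch_size=2, data_parallel_size=2 and the default
# drop_last=False, rank 0 yields [0, 1], [4, 5] and then the leftover [8],
# while rank 1 yields [2, 3], [6, 7] and [9] (assuming utils.split_index
# spreads the leftover samples evenly, one slice per rank).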


class EpisodeDataSampler:
    """Cycles the dataset and aligns micro-batches to episode boundaries, yielding per-rank slices."""
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, sample_per_episode, drop_last=False):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
self.data_parallel_size = data_parallel_size
if self.drop_last:
last_samples = self.total_samples % self.micro_batch_times_data_parallel_size
assert self.total_samples > last_samples, \
                'total_samples is not enough to perform drop_last!'
self.total_samples -= last_samples
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
if self.consumed_samples >= self.total_samples:
self.consumed_samples = self.consumed_samples % self.total_samples
assert self.micro_batch_size > 0
assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data_parallel_size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)
self.episode_offset = 0
self.sample_per_episode = sample_per_episode
def __len__(self):
return self.total_samples
def get_start_end_idx(self, batch):
indices = utils.split_index(len(batch), self.data_parallel_size)
return indices[self.data_parallel_rank]
def iter_internal(self, batch):
        # Wrap consumed_samples so iteration cycles through the dataset.
if self.consumed_samples >= self.total_samples:
self.consumed_samples = self.consumed_samples % self.total_samples
for idx in chain(range(self.consumed_samples, self.total_samples), range(self.consumed_samples)):
batch.append(idx)
self.episode_offset += 1
self.consumed_samples += 1
if len(batch) == self.micro_batch_times_data_parallel_size or \
self.episode_offset == self.sample_per_episode:
return True
return False
def __iter__(self):
batch = []
while True:
            # Fill the batch; drop_last was already applied in __init__ by trimming total_samples.
batch_gen_flag = self.iter_internal(batch)
            # If the dataset was exhausted before the batch filled, wrap around
            # and keep filling up to the episode boundary (sample_per_episode).
            if len(batch) > 0 and not batch_gen_flag:
                batch_gen_flag = self.iter_internal(batch)
if batch_gen_flag:
start_idx, end_idx = self.get_start_end_idx(batch)
yield batch[start_idx:end_idx]
batch = []
if self.episode_offset == self.sample_per_episode:
self.episode_offset = 0
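
# Illustrative behaviour of EpisodeDataSampler (hypothetical numbers). With
# total_samples=5, micro_batch_size=2, data_parallel_size=2 and
# sample_per_episode=4, the sampler cycles the dataset indefinitely: rank 0
# sees [0, 1] from the episode batch [0, 1, 2, 3], then [4, 0] from the
# wrapped batch [4, 0, 1, 2], and so on (again assuming an even
# utils.split_index).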


class RLHFSingleSampler:
    """Index sampler for a single RLHF dataset with shuffling and configurable last-batch handling."""
def __init__(
self,
total_samples,
consumed_samples,
batch_size,
shuffle=True,
seed=0,
drop_last="cycle"):
"""
drop_last:
"drop": drop last
"retain": return remaining samples
"cycle": loop back to the beginning
"""
self.drop_last = drop_last
self.shuffle = shuffle
self.seed = seed
self.total_samples = total_samples
self.consumed_samples = consumed_samples
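        # In "drop" mode an epoch exposes only the largest multiple of
        # batch_size; in "retain" and "cycle" modes every sample is visible.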
self.actual_samples = total_samples // batch_size * batch_size if self.drop_last == "drop" else total_samples
self.offset = consumed_samples % self.actual_samples
self.curr_epoch = consumed_samples // self.actual_samples
self.update_random_idxs()
def update_random_idxs(self):
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.curr_epoch + self.seed)
self.random_idx = torch.randperm(self.total_samples, generator=g).tolist()
else:
self.random_idx = list(range(self.total_samples))
def get_next(self, num):
batch = []
if self.drop_last == "drop":
assert num <= self.total_samples, 'drop mode does not support batch_size larger than dataset size'
if num > self.total_samples - self.offset:
self.offset = 0
self.curr_epoch += 1
self.update_random_idxs()
batch.extend(self.random_idx[self.offset : self.offset + num])
self.offset = self.offset + num
return batch
elif self.drop_last == "retain":
if num > self.total_samples - self.offset:
batch.extend(self.random_idx[self.offset : self.total_samples])
self.offset = 0
self.curr_epoch += 1
self.update_random_idxs()
else:
batch.extend(self.random_idx[self.offset : self.offset + num])
self.offset = self.offset + num
return batch
else:
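            # "cycle" mode: keep wrapping around (re-shuffling at each epoch
            # boundary when shuffle=True) until `num` indices are collected.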
while num >= self.total_samples - self.offset:
batch.extend(self.random_idx[self.offset : self.total_samples])
num -= (self.total_samples - self.offset)
self.offset = 0
self.curr_epoch += 1
self.update_random_idxs()
batch.extend(self.random_idx[self.offset : self.offset + num])
self.offset = self.offset + num
return batch
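
# Illustrative walk-through of the "cycle" mode (hypothetical numbers). With
# total_samples=5 and shuffle=False, get_next(3) returns [0, 1, 2]; the next
# get_next(3) wraps around to [3, 4, 0], bumping curr_epoch at the boundary.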


class MultiDatasetSampler:
    """RLHF sampler that mixes multiple datasets according to data_ratio."""
def __init__(
self,
dataset_sizes,
sample_per_episode,
data_ratio=None,
consumed_samples=0,
num_inference_per_prompt=1,
shuffle=True,
seed=None,
is_eval=False,
init_shuffle_prompt=0,
data_parallel_rank=0,
data_parallel_size=1,
drop_last="cycle",
data_rerank=False
):
"""
drop_last:
"drop": drop last
"retain": return remaining samples
"cycle": loop back to the beginning
"""
self.dataset_sizes = dataset_sizes
self.dataset_num = len(dataset_sizes)
self.data_parallel_size = data_parallel_size
self.data_parallel_rank = data_parallel_rank
self.sample_per_episode = sample_per_episode
self.consumed_samples = consumed_samples
self.is_eval = is_eval
self.num_inference_per_prompt = num_inference_per_prompt
self.shuffle = shuffle
self.seeds = [0] * self.dataset_num if seed is None else [seed] * self.dataset_num
if self.dataset_num == 1:
self.drop_last = drop_last
else:
self.drop_last = "cycle"
self.data_rerank = data_rerank
        assert init_shuffle_prompt == 0, "init_shuffle_prompt=1 or 2 is not supported yet"
        assert self.consumed_samples % self.num_inference_per_prompt == 0, \
            "consumed_samples must be an integer multiple of num_inference_per_prompt"
        assert self.sample_per_episode % self.num_inference_per_prompt == 0, \
            "sample_per_episode must be an integer multiple of num_inference_per_prompt"
if not self.is_eval:
if data_ratio is None:
self.data_ratio = [1] * self.dataset_num
elif isinstance(data_ratio, int):
self.data_ratio = [data_ratio] * self.dataset_num
elif isinstance(data_ratio, list):
assert len(data_ratio) == self.dataset_num, (
"expect data_ratio to be a list with the same length as the number of datasets, "
f"got {len(data_ratio)} and {self.dataset_num}."
)
self.data_ratio = data_ratio
else:
raise TypeError(f"unexpected data_ratio type {type(data_ratio)}, expect int or List.")
consumed_each, self.dataset_remains = self.cal_consumed_each(self.consumed_samples // self.num_inference_per_prompt, self.data_ratio)
self.samplers = [
RLHFSingleSampler(
self.dataset_sizes[i],
consumed_each[i],
batch_size=self.sample_per_episode // self.num_inference_per_prompt if self.dataset_num == 1 else self.data_ratio[i],
shuffle=self.shuffle,
seed=self.seeds[i],
drop_last=self.drop_last
)
for i in range(self.dataset_num)
]
def cal_consumed_each(self, consumed_samples, data_ratio):
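        # Worked example: data_ratio=[2, 1] and consumed_samples=4 yield
        # consumed_each=[3, 1] (two from dataset 0, one from dataset 1, then
        # one more from dataset 0) and dataset_remains=[1, 1], the unfinished
        # part of the current ratio cycle.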
multiples = consumed_samples // sum(data_ratio)
consumed_each = [r * multiples for r in data_ratio]
remains = consumed_samples % sum(data_ratio)
i = 0
dataset_remains = data_ratio[:]
while remains != 0:
if i == 0:
dataset_remains = data_ratio[:]
dataset_remains[i] -= min(remains, data_ratio[i])
consumed_each[i] += min(remains, data_ratio[i])
remains -= min(remains, data_ratio[i])
i = (i + 1) % self.dataset_num
return consumed_each, dataset_remains
def repeat(self, data, interleave=False):
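        # With num_inference_per_prompt=2, interleave=True maps [a, b] to
        # [a, a, b, b], while interleave=False maps it to [a, b, a, b].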
if interleave:
res = []
for d in data:
res.extend([d] * self.num_inference_per_prompt)
return res
else:
return data * self.num_inference_per_prompt
def __iter__(self):
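        # Split sample_per_episode as evenly as possible across ranks; the
        # first `remainder` ranks each take one extra sample.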
self.remainder = self.sample_per_episode % self.data_parallel_size
batch_size_list = [self.sample_per_episode // self.data_parallel_size + 1] * self.remainder + \
[self.sample_per_episode // self.data_parallel_size] * (self.data_parallel_size - self.remainder)
start, end = sum(batch_size_list[:self.data_parallel_rank]), sum(batch_size_list[:self.data_parallel_rank + 1])
if self.is_eval:
            assert self.sample_per_episode <= sum(self.dataset_sizes), \
                "total eval dataset size must be at least sample_per_episode"
idxes = []
for dataset_idx, dataset_size in enumerate(self.dataset_sizes):
idxes.extend([(dataset_idx, j, (len(idxes) + j)) for j in range(dataset_size)])
for i in range(0, len(idxes), self.sample_per_episode):
episode_samples = idxes[i : i + self.sample_per_episode]
yield episode_samples[start : end]
else:
num_samples = self.sample_per_episode // self.num_inference_per_prompt
while True:
if self.drop_last == "drop":
batch = self.samplers[0].get_next(num_samples)
batch_idxes = [(0, batch[i], i) for i in range(len(batch))]
batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank)
batch_idxes = batch_idxes[start : end]
elif self.drop_last == "retain":
batch = self.samplers[0].get_next(num_samples)
batch_idxes = [(0, batch[i], i) for i in range(len(batch))]
batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank)
new_batch_size_list = [len(batch_idxes) // self.data_parallel_size] * self.data_parallel_size
for i in range(len(batch_idxes) % self.data_parallel_size):
new_batch_size_list[i] += 1
batch_idxes = \
batch_idxes[sum(new_batch_size_list[:self.data_parallel_rank]) : sum(new_batch_size_list[:self.data_parallel_rank + 1])]
else:
batch_idxes = []
dataset_id = 0
while len(batch_idxes) != num_samples:
data_num = min(self.dataset_remains[dataset_id], num_samples - len(batch_idxes))
self.dataset_remains[dataset_id] -= data_num
batch = self.samplers[dataset_id].get_next(data_num)
batch_idxes.extend([(dataset_id, batch[i], (len(batch_idxes) + i)) for i in range(data_num)])
dataset_id = (dataset_id + 1) % self.dataset_num
if self.dataset_remains == [0] * self.dataset_num:
self.dataset_remains = self.data_ratio[:]
batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank)
batch_idxes = batch_idxes[start : end]
yield batch_idxes
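
# Illustrative episode layout (hypothetical numbers). With dataset_sizes=[4, 2],
# data_ratio=[2, 1], sample_per_episode=3, num_inference_per_prompt=1,
# shuffle=False and a single data parallel rank, successive episodes yield
# [(0, 0, 0), (0, 1, 1), (1, 0, 2)] and then [(0, 2, 0), (0, 3, 1), (1, 1, 2)],
# where each tuple is (dataset_id, index_in_dataset, index_in_episode).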