chatlearn/data/sampler.py (284 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """data sampler.""" from itertools import chain import torch from chatlearn.utils import utils class SingleDataSampler: """SingleDataSampler""" def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, dynamic_batch_size_flag=False, drop_last=False): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.remainder = (total_samples - consumed_samples) % data_parallel_size \ if dynamic_batch_size_flag else 0 self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size + self.remainder self.drop_last = drop_last self.data_parallel_size = data_parallel_size # Sanity checks. assert self.total_samples > 0, \ 'no sample to consume: {}'.format(self.total_samples) assert self.consumed_samples < self.total_samples, \ 'no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 assert self.data_parallel_rank < data_parallel_size, \ 'data_parallel_rank should be smaller than data size: {}, ' \ '{}'.format(self.data_parallel_rank, data_parallel_size) def __len__(self): return self.total_samples def get_start_end_idx(self): start_batch_size_plus = self.data_parallel_rank if self.data_parallel_rank < self.remainder else self.remainder start_idx = self.data_parallel_rank * self.micro_batch_size + start_batch_size_plus batch_size_plus = 1 if self.data_parallel_rank < self.remainder else 0 batch_size = self.micro_batch_size + batch_size_plus end_idx = start_idx + batch_size return start_idx, end_idx def __iter__(self): batch = [] # Last batch will be dropped if drop_last is not set False for idx in range(self.consumed_samples, self.total_samples): batch.append(idx) if len(batch) == self.micro_batch_times_data_parallel_size: start_idx, end_idx = self.get_start_end_idx() yield batch[start_idx:end_idx] batch = [] # Check the last partial batch and see drop_last is set if len(batch) > 0 and not self.drop_last: indices = utils.split_index(len(batch), self.data_parallel_size) start_idx, end_idx = indices[self.data_parallel_rank] yield batch[start_idx:end_idx] class EpisodeDataSampler: """EpisodeDataSampler""" def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, sample_per_episode, drop_last=False): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.micro_batch_times_data_parallel_size = \ self.micro_batch_size * data_parallel_size self.drop_last = drop_last self.data_parallel_size = data_parallel_size if self.drop_last: last_samples = self.total_samples % self.micro_batch_times_data_parallel_size assert self.total_samples > last_samples, \ 'total_samples are not enough to perform drop_last!' self.total_samples -= last_samples # Sanity checks. assert self.total_samples > 0, \ 'no sample to consume: {}'.format(self.total_samples) if self.consumed_samples >= self.total_samples: self.consumed_samples = self.consumed_samples % self.total_samples assert self.micro_batch_size > 0 assert data_parallel_size > 0 assert self.data_parallel_rank < data_parallel_size, \ 'data_parallel_rank should be smaller than data size: {}, ' \ '{}'.format(self.data_parallel_rank, data_parallel_size) self.episode_offset = 0 self.sample_per_episode = sample_per_episode def __len__(self): return self.total_samples def get_start_end_idx(self, batch): indices = utils.split_index(len(batch), self.data_parallel_size) return indices[self.data_parallel_rank] def iter_internal(self, batch): # for cycle purpose if self.consumed_samples >= self.total_samples: self.consumed_samples = self.consumed_samples % self.total_samples for idx in chain(range(self.consumed_samples, self.total_samples), range(self.consumed_samples)): batch.append(idx) self.episode_offset += 1 self.consumed_samples += 1 if len(batch) == self.micro_batch_times_data_parallel_size or \ self.episode_offset == self.sample_per_episode: return True return False def __iter__(self): batch = [] while True: # Last batch will be dropped if drop_last is set True batch_gen_flag = self.iter_internal(batch) # Check the last partial batch and see drop_last is set if len(batch) > 0 and not batch_gen_flag: # wrap it to sample_per_episode batch_gen_flag = self.iter_internal(batch) if batch_gen_flag: start_idx, end_idx = self.get_start_end_idx(batch) yield batch[start_idx:end_idx] batch = [] if self.episode_offset == self.sample_per_episode: self.episode_offset = 0 class RLHFSingleSampler: """RLHF sampler for a dataset. """ def __init__( self, total_samples, consumed_samples, batch_size, shuffle=True, seed=0, drop_last="cycle"): """ drop_last: "drop": drop last "retain": return remaining samples "cycle": loop back to the beginning """ self.drop_last = drop_last self.shuffle = shuffle self.seed = seed self.total_samples = total_samples self.consumed_samples = consumed_samples self.actual_samples = total_samples // batch_size * batch_size if self.drop_last == "drop" else total_samples self.offset = consumed_samples % self.actual_samples self.curr_epoch = consumed_samples // self.actual_samples self.update_random_idxs() def update_random_idxs(self): if self.shuffle: g = torch.Generator() g.manual_seed(self.curr_epoch + self.seed) self.random_idx = torch.randperm(self.total_samples, generator=g).tolist() else: self.random_idx = list(range(self.total_samples)) def get_next(self, num): batch = [] if self.drop_last == "drop": assert num <= self.total_samples, 'drop mode does not support batch_size larger than dataset size' if num > self.total_samples - self.offset: self.offset = 0 self.curr_epoch += 1 self.update_random_idxs() batch.extend(self.random_idx[self.offset : self.offset + num]) self.offset = self.offset + num return batch elif self.drop_last == "retain": if num > self.total_samples - self.offset: batch.extend(self.random_idx[self.offset : self.total_samples]) self.offset = 0 self.curr_epoch += 1 self.update_random_idxs() else: batch.extend(self.random_idx[self.offset : self.offset + num]) self.offset = self.offset + num return batch else: while num >= self.total_samples - self.offset: batch.extend(self.random_idx[self.offset : self.total_samples]) num -= (self.total_samples - self.offset) self.offset = 0 self.curr_epoch += 1 self.update_random_idxs() batch.extend(self.random_idx[self.offset : self.offset + num]) self.offset = self.offset + num return batch class MultiDatasetSampler: """RLHF sampler for multiple datasets. """ def __init__( self, dataset_sizes, sample_per_episode, data_ratio=None, consumed_samples=0, num_inference_per_prompt=1, shuffle=True, seed=None, is_eval=False, init_shuffle_prompt=0, data_parallel_rank=0, data_parallel_size=1, drop_last="cycle", data_rerank=False ): """ drop_last: "drop": drop last "retain": return remaining samples "cycle": loop back to the beginning """ self.dataset_sizes = dataset_sizes self.dataset_num = len(dataset_sizes) self.data_parallel_size = data_parallel_size self.data_parallel_rank = data_parallel_rank self.sample_per_episode = sample_per_episode self.consumed_samples = consumed_samples self.is_eval = is_eval self.num_inference_per_prompt = num_inference_per_prompt self.shuffle = shuffle self.seeds = [0] * self.dataset_num if seed is None else [seed] * self.dataset_num if self.dataset_num == 1: self.drop_last = drop_last else: self.drop_last = "cycle" self.data_rerank = data_rerank assert init_shuffle_prompt == 0, "init_shuffle_prompt=1, 2 is not supported yet" assert self.consumed_samples % self.num_inference_per_prompt == 0, "consumed samples must be integer multiple of num_inference_per_prompt" assert self.sample_per_episode % self.num_inference_per_prompt == 0, "sample_per_episode must be integer multiple of num_inference_per_prompt" if not self.is_eval: if data_ratio is None: self.data_ratio = [1] * self.dataset_num elif isinstance(data_ratio, int): self.data_ratio = [data_ratio] * self.dataset_num elif isinstance(data_ratio, list): assert len(data_ratio) == self.dataset_num, ( "expect data_ratio to be a list with the same length as the number of datasets, " f"got {len(data_ratio)} and {self.dataset_num}." ) self.data_ratio = data_ratio else: raise TypeError(f"unexpected data_ratio type {type(data_ratio)}, expect int or List.") consumed_each, self.dataset_remains = self.cal_consumed_each(self.consumed_samples // self.num_inference_per_prompt, self.data_ratio) self.samplers = [ RLHFSingleSampler( self.dataset_sizes[i], consumed_each[i], batch_size=self.sample_per_episode // self.num_inference_per_prompt if self.dataset_num == 1 else self.data_ratio[i], shuffle=self.shuffle, seed=self.seeds[i], drop_last=self.drop_last ) for i in range(self.dataset_num) ] def cal_consumed_each(self, consumed_samples, data_ratio): multiples = consumed_samples // sum(data_ratio) consumed_each = [r * multiples for r in data_ratio] remains = consumed_samples % sum(data_ratio) i = 0 dataset_remains = data_ratio[:] while remains != 0: if i == 0: dataset_remains = data_ratio[:] dataset_remains[i] -= min(remains, data_ratio[i]) consumed_each[i] += min(remains, data_ratio[i]) remains -= min(remains, data_ratio[i]) i = (i + 1) % self.dataset_num return consumed_each, dataset_remains def repeat(self, data, interleave=False): if interleave: res = [] for d in data: res.extend([d] * self.num_inference_per_prompt) return res else: return data * self.num_inference_per_prompt def __iter__(self): self.remainder = self.sample_per_episode % self.data_parallel_size batch_size_list = [self.sample_per_episode // self.data_parallel_size + 1] * self.remainder + \ [self.sample_per_episode // self.data_parallel_size] * (self.data_parallel_size - self.remainder) start, end = sum(batch_size_list[:self.data_parallel_rank]), sum(batch_size_list[:self.data_parallel_rank + 1]) if self.is_eval: assert self.sample_per_episode <= sum(self.dataset_sizes), "eval dataset size must be larger than sample_per_episode" idxes = [] for dataset_idx, dataset_size in enumerate(self.dataset_sizes): idxes.extend([(dataset_idx, j, (len(idxes) + j)) for j in range(dataset_size)]) for i in range(0, len(idxes), self.sample_per_episode): episode_samples = idxes[i : i + self.sample_per_episode] yield episode_samples[start : end] else: num_samples = self.sample_per_episode // self.num_inference_per_prompt while True: if self.drop_last == "drop": batch = self.samplers[0].get_next(num_samples) batch_idxes = [(0, batch[i], i) for i in range(len(batch))] batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank) batch_idxes = batch_idxes[start : end] elif self.drop_last == "retain": batch = self.samplers[0].get_next(num_samples) batch_idxes = [(0, batch[i], i) for i in range(len(batch))] batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank) new_batch_size_list = [len(batch_idxes) // self.data_parallel_size] * self.data_parallel_size for i in range(len(batch_idxes) % self.data_parallel_size): new_batch_size_list[i] += 1 batch_idxes = \ batch_idxes[sum(new_batch_size_list[:self.data_parallel_rank]) : sum(new_batch_size_list[:self.data_parallel_rank + 1])] else: batch_idxes = [] dataset_id = 0 while len(batch_idxes) != num_samples: data_num = min(self.dataset_remains[dataset_id], num_samples - len(batch_idxes)) self.dataset_remains[dataset_id] -= data_num batch = self.samplers[dataset_id].get_next(data_num) batch_idxes.extend([(dataset_id, batch[i], (len(batch_idxes) + i)) for i in range(data_num)]) dataset_id = (dataset_id + 1) % self.dataset_num if self.dataset_remains == [0] * self.dataset_num: self.dataset_remains = self.data_ratio[:] batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank) batch_idxes = batch_idxes[start : end] yield batch_idxes