in chatlearn/data/sampler.py [0:0]
def __iter__(self):
    """Yield per-episode lists of sample indices for this data-parallel rank.

    Each yielded item is a list of ``(dataset_idx, sample_idx_in_dataset,
    position_in_episode)`` tuples covering this rank's slice of one episode.

    Eval mode (finite): walks every dataset once, front to back, and slices
    each episode of ``sample_per_episode`` indices by rank.

    Train mode (infinite): per episode, draws
    ``sample_per_episode // num_inference_per_prompt`` prompts, expands them
    via ``self.repeat`` (presumably duplicating each prompt
    ``num_inference_per_prompt`` times, with ``interleave`` controlling the
    ordering — confirm against ``repeat``'s definition), then slices by rank
    according to ``self.drop_last``:

    * ``"drop"``:   use the precomputed per-rank boundaries; any samples
                    beyond them are dropped.
    * ``"retain"``: recompute boundaries from the actual expanded batch
                    length so every sample is kept (ranks may get uneven
                    counts).
    * otherwise:    fill the episode round-robin over multiple datasets
                    using ``self.data_ratio`` quotas, then repeat and slice
                    with the precomputed boundaries.
    """
    # Split sample_per_episode across DP ranks; the first `remainder`
    # ranks get one extra sample when it does not divide evenly.
    self.remainder = self.sample_per_episode % self.data_parallel_size
    base = self.sample_per_episode // self.data_parallel_size
    batch_size_list = [base + 1] * self.remainder + \
        [base] * (self.data_parallel_size - self.remainder)
    # Half-open [start, end) slice owned by this rank within an episode.
    start, end = sum(batch_size_list[:self.data_parallel_rank]), sum(batch_size_list[:self.data_parallel_rank + 1])
    if self.is_eval:
        # `<=` allows equality, so the message says "at least", not "larger".
        assert self.sample_per_episode <= sum(self.dataset_sizes), \
            "eval dataset size must be at least sample_per_episode"
        # Enumerate every sample of every dataset with a running global index.
        idxes = []
        for dataset_idx, dataset_size in enumerate(self.dataset_sizes):
            # Global index of this dataset's first sample; constant during
            # the comprehension below (extend happens after it finishes).
            offset = len(idxes)
            idxes.extend([(dataset_idx, j, offset + j) for j in range(dataset_size)])
        for i in range(0, len(idxes), self.sample_per_episode):
            episode_samples = idxes[i : i + self.sample_per_episode]
            yield episode_samples[start : end]
    else:
        # Number of distinct prompts per episode, before repetition.
        num_samples = self.sample_per_episode // self.num_inference_per_prompt
        while True:
            if self.drop_last in ("drop", "retain"):
                # Single-dataset path, shared by both modes: draw prompts,
                # tag them with dataset 0, and expand by repetition.
                batch = self.samplers[0].get_next(num_samples)
                batch_idxes = [(0, batch[i], i) for i in range(len(batch))]
                batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank)
                if self.drop_last == "drop":
                    # Fixed per-rank boundaries; samples beyond them are lost.
                    batch_idxes = batch_idxes[start : end]
                else:
                    # "retain": rebalance boundaries over the actual batch
                    # length so no sample is dropped.
                    new_batch_size_list = [len(batch_idxes) // self.data_parallel_size] * self.data_parallel_size
                    for i in range(len(batch_idxes) % self.data_parallel_size):
                        new_batch_size_list[i] += 1
                    batch_idxes = \
                        batch_idxes[sum(new_batch_size_list[:self.data_parallel_rank]) : sum(new_batch_size_list[:self.data_parallel_rank + 1])]
            else:
                # Multi-dataset path: fill the episode round-robin, taking
                # at most `dataset_remains[d]` samples from dataset d per
                # pass, per the configured data_ratio quotas.
                batch_idxes = []
                dataset_id = 0
                while len(batch_idxes) != num_samples:
                    data_num = min(self.dataset_remains[dataset_id], num_samples - len(batch_idxes))
                    self.dataset_remains[dataset_id] -= data_num
                    batch = self.samplers[dataset_id].get_next(data_num)
                    batch_idxes.extend([(dataset_id, batch[i], (len(batch_idxes) + i)) for i in range(data_num)])
                    dataset_id = (dataset_id + 1) % self.dataset_num
                    # Replenish quotas once all are exhausted so the fill
                    # loop can keep making progress.
                    if self.dataset_remains == [0] * self.dataset_num:
                        self.dataset_remains = self.data_ratio[:]
                batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank)
                batch_idxes = batch_idxes[start : end]
            yield batch_idxes