in src/nanotron/data/samplers.py
def __iter__(self):
    batch = []
    batch_idx = 0
    # Last batch will be dropped if drop_last is not set to False
    for idx in range(self.consumed_samples, self.total_samples):
        batch.append(idx)
        if len(batch) == self.micro_batch_times_data_parallel_size:
            start_idx, end_idx = self.get_start_end_idx()
            log_rank(
                f"DP {self.data_parallel_rank} batch {batch_idx} {batch[start_idx:end_idx]} self.consumed_samples {self.consumed_samples}",
                logger=logger,
                level=logging.DEBUG,
            )
            # self.last_consumed_sample_all_ranks = batch[-1] # = self.consumed_samples?
            self.consumed_samples += self.micro_batch_times_data_parallel_size
            yield batch[start_idx:end_idx]
            batch = []
            batch_idx += 1

    # Check the last partial batch and see whether drop_last is set
    if len(batch) > 0 and not self.drop_last:
        if self.pad_samples_to_global_batch_size:
            for i in range(
                self.data_parallel_rank, self.global_batch_size, self.micro_batch_times_data_parallel_size
            ):
                # Take only the samples that actually remain in the partial batch
                # (min avoids indexing past its end), then pad up to micro_batch_size with -1.
                indices = [batch[j] for j in range(i, min(len(batch), i + self.micro_batch_size))]
                num_pad = self.micro_batch_size - len(indices)
                indices = indices + [-1] * num_pad
                yield indices
        else:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
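
For intuition, here is a minimal standalone sketch (not nanotron code) of how the full-batch path above splits each global step across data-parallel ranks. It uses made-up sizes and assumes get_start_end_idx() returns the contiguous per-rank slice (data_parallel_rank * micro_batch_size, start + micro_batch_size); that method is not shown in this excerpt.

micro_batch_size = 2
data_parallel_size = 2
micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
total_samples = 9
consumed_samples = 0

def iter_rank(data_parallel_rank):
    """Micro-batches a single data-parallel rank would receive (full-batch path)."""
    batch = []
    for idx in range(consumed_samples, total_samples):
        batch.append(idx)
        if len(batch) == micro_batch_times_data_parallel_size:
            # Assumed behaviour of get_start_end_idx(): contiguous per-rank slice.
            start_idx = data_parallel_rank * micro_batch_size
            end_idx = start_idx + micro_batch_size
            yield batch[start_idx:end_idx]
            batch = []

for rank in range(data_parallel_size):
    print(rank, list(iter_rank(rank)))
# 0 [[0, 1], [4, 5]]
# 1 [[2, 3], [6, 7]]
# Sample 8 is left over: it is dropped when drop_last is set, or padded with -1
# placeholders (e.g. [8, -1]) when pad_samples_to_global_batch_size is enabled.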