in src/nanotron/data/samplers.py [0:0]
def __post_init__(self):
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
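# Number of samples consumed per step across all data-parallel ranks (one micro-batch per rank).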
# Sanity checks.
if self.total_samples <= 0:
raise RuntimeError("no sample to consume: {}".format(self.total_samples))
if self.consumed_samples >= self.total_samples:
raise RuntimeError("no samples left to consume: {}, {}".format(self.consumed_samples, self.total_samples))
if self.micro_batch_size <= 0:
raise RuntimeError(f"micro_batch_size size must be greater than 0, but {self.micro_batch_size}")
if self.data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0, but {self.data_parallel_size}")
if self.data_parallel_rank >= self.data_parallel_size:
raise RuntimeError(
f"data_parallel_rank should be smaller than data_parallel_size, but {self.data_parallel_rank} >= {self.data_parallel_size}"
)
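# The global batch must split evenly into one micro-batch per data-parallel rank
# (the quotient is typically the number of gradient accumulation steps).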
if self.global_batch_size is not None and self.global_batch_size % self.micro_batch_times_data_parallel_size != 0:
raise RuntimeError(
f"`global_batch_size` ({self.global_batch_size}) is not divisible by "
f"`micro_batch_size` ({self.micro_batch_size}) x `data_parallel_size` ({self.data_parallel_size})"
)
if self.pad_samples_to_global_batch_size and self.global_batch_size is None:
raise RuntimeError(
"`pad_samples_to_global_batch_size` can be `True` only when "
"`global_batch_size` is set to an integer value"
)
log_rank(
f"Instantiating MegatronPretrainingSampler with total_samples: {self.total_samples} and consumed_samples: {self.consumed_samples}",
logger=logger,
level=logging.INFO,
rank=0,
)
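# Minimal usage sketch with illustrative values (assumes MegatronPretrainingSampler is a
# dataclass exposing the fields validated above):
#
#     sampler = MegatronPretrainingSampler(
#         total_samples=10_000,
#         consumed_samples=0,
#         micro_batch_size=4,
#         data_parallel_rank=0,
#         data_parallel_size=8,
#         global_batch_size=64,  # 64 % (4 * 8) == 0, so the divisibility check passes
#         pad_samples_to_global_batch_size=False,
#     )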