# timm/data/naflex_dataset.py
def _create_canonical_schedule(self):
    """
    Calculates the canonical batch schedule of (seq_len, batch_size) pairs
    based on the dataset size, padded for distributed training.

    This schedule is the *same* for all ranks and ensures a consistent
    epoch length. It is calculated once during initialization.
    """
    total_len = len(self.base_dataset)
    padded_total_len = total_len
    num_samples_per_rank = total_len
    if self.distributed and self.world_size > 1:
        # Calculate the padding needed for an even split across ranks
        if total_len % self.world_size != 0:
            pad_size = self.world_size - (total_len % self.world_size)
            padded_total_len += pad_size
            print(
                f"Rank {self.rank}: Padding dataset with {pad_size} samples "
                f"for distributed training (total size {padded_total_len})."
            )
        else:
            pad_size = 0
        if padded_total_len % self.world_size != 0:
            # Unreachable given the padding above; kept as a safeguard.
            raise RuntimeError(
                f"Internal Error: Padded total length {padded_total_len} "
                f"not divisible by world size {self.world_size}"
            )
        num_samples_per_rank = padded_total_len // self.world_size
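        # Worked example (illustrative): total_len=1002, world_size=8 gives
        # pad_size = 8 - (1002 % 8) = 6, padded_total_len = 1008, and
        # num_samples_per_rank = 1008 // 8 = 126.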
    elif self.distributed and self.world_size <= 1:
        # Distributed flag set but world_size is 1: treat as non-distributed,
        # so num_samples_per_rank remains total_len.
        pass
    self._padded_samples_per_rank = num_samples_per_rank

    if num_samples_per_rank == 0:
        self._canonical_batch_schedule = []
        self._num_batches_per_rank = 0
        return
    # Use a fixed generator seeded with the base seed (NOT the per-epoch seed)
    # so the canonical schedule structure is identical on every rank and epoch.
    g = torch.Generator()
    g.manual_seed(self.seed)

    current_schedule: List[Tuple[int, int]] = []
    remaining_samples = num_samples_per_rank
    total_scheduled_samples = 0
    while remaining_samples > 0:
        # Sample a sequence length deterministically from the base seed
        seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
        seq_len = self.seq_lens[seq_idx]

        # Calculate the batch size that fits the token budget for this seq_len
        batch_size = calculate_naflex_batch_size(
            tokens_per_batch=self.max_tokens_per_batch,
            seq_len=seq_len,
            max_size=remaining_samples,  # cap at remaining to avoid overshooting
            divisor=self.batch_divisor,
            rounding='floor',
        )
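        # Illustrative arithmetic, assuming the helper floors the image count:
        # with max_tokens_per_batch=4096 and seq_len=128 it would yield
        # 4096 // 128 = 32 images, rounded down to a multiple of
        # self.batch_divisor and capped at remaining_samples.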
        # Clamp to [1, remaining_samples]; the lower bound guarantees progress
        batch_size = max(1, batch_size)
        batch_size = min(batch_size, remaining_samples)
        if batch_size <= 0:
            # Unreachable after the clamping above; defensive guard against
            # an infinite loop if the logic above ever changes.
            warnings.warn(
                f"Calculated batch size <= 0 (seq_len={seq_len}, "
                f"remaining={remaining_samples}). Stopping schedule generation early."
            )
            break

        current_schedule.append((seq_len, batch_size))
        remaining_samples -= batch_size
        total_scheduled_samples += batch_size
    # Sanity check: ensure the schedule covers every sample assigned to this rank
    if total_scheduled_samples != num_samples_per_rank:
        warnings.warn(
            f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
            f"but expected {num_samples_per_rank} samples per rank. "
            f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. "
            f"Check parameters. Remaining samples: {remaining_samples}"
        )
        # No adjustment is attempted here: appending a final small batch could
        # violate the divisor constraints, so any uncovered samples are dropped.
    self._canonical_batch_schedule = current_schedule
    self._num_batches_per_rank = len(current_schedule)
    print(
        f"Rank {self.rank}: Created canonical schedule with "
        f"{self._num_batches_per_rank} batches for "
        f"{self._padded_samples_per_rank} samples/rank."
    )