in timm/data/naflex_dataset.py
def _prepare_epoch_batches(self, epoch: int):
    """
    Prepare the batches for the current epoch by:

    1. Shuffling the full dataset indices (using an epoch-specific seed).
    2. Applying padding if in distributed mode.
    3. Selecting the indices for the current rank.
    4. Shuffling the *order* of the canonical batch schedule (using the same epoch seed).
    5. Assigning the rank's indices to the shuffled batches.
    """
    g = torch.Generator()
    g.manual_seed(self.seed + epoch)  # Epoch-specific seed for shuffling

    # 1. Get shuffled global indices
    total_len = len(self.base_dataset)
    if self.shuffle:
        all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
    else:
        all_indices_shuffled = list(range(total_len))
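    # NOTE: every rank seeds with the same `self.seed + epoch` (assuming `self.seed`
    # is set identically on all ranks), so each rank computes the same global
    # permutation and the per-rank strided split below stays consistent across ranks.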

    # 2. Apply padding for distributed mode
    indices_for_ranks = all_indices_shuffled
    if self.distributed and self.world_size > 1:
        padded_total_len = self._padded_samples_per_rank * self.world_size
        if padded_total_len > total_len:
            pad_size = padded_total_len - total_len
            # Repeat initial elements from the *shuffled* list for padding
            indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
        # Ensure the padded length matches expectations
        if len(indices_for_ranks) != padded_total_len:
            raise RuntimeError(
                f"Internal Error: Padded index list length {len(indices_for_ranks)} "
                f"does not match expected {padded_total_len}"
            )
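    # Padding example: with total_len=10, world_size=4 and _padded_samples_per_rank=3,
    # padded_total_len=12, so the first two shuffled indices are repeated at the end.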

    # 3. Select indices for the current rank
    if self.distributed and self.world_size > 1:
        indices_this_rank = indices_for_ranks[self.rank::self.world_size]
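        # e.g. with world_size=4, rank 1 takes padded positions 1, 5, 9, ...
        # (an interleaved/strided shard rather than a contiguous block)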
    else:  # Non-distributed or world_size == 1
        indices_this_rank = indices_for_ranks

    # Sanity check length
    if len(indices_this_rank) != self._padded_samples_per_rank:
        # This might happen if canonical schedule generation had warnings/issues
        warnings.warn(
            f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
            f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
            f"Epoch generation might be inconsistent."
        )
        # Proceed, but warn and clamp to the smaller count; using min() prevents an
        # IndexError later if there are fewer indices than expected.
        effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
        indices_this_rank = indices_this_rank[:effective_samples_this_rank]
    else:
        effective_samples_this_rank = self._padded_samples_per_rank

    # 4. Shuffle the order of the canonical batch schedule for this epoch
    if self.shuffle:
        schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
        shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
    else:
        shuffled_schedule = list(self._canonical_batch_schedule)  # Keep original order
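    # The same generator `g` drives this permutation, so every rank walks the
    # (seq_len, batch_size) schedule entries in the same order (again assuming
    # `self.seed` matches across ranks).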

    # 5. Assign indices to the shuffled batches
    self._epoch_batches = []
    idx_pos = 0
    scheduled_samples_count = 0
    for seq_len, bs in shuffled_schedule:
        # Ensure we don't try to grab more indices than remain for this rank
        actual_bs = min(bs, effective_samples_this_rank - idx_pos)
        if actual_bs <= 0:
            if scheduled_samples_count < effective_samples_this_rank:
                # This indicates a mismatch between the schedule total and the actual samples
                warnings.warn(
                    f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) "
                    f"before processing the entire schedule. Check schedule generation."
                )
            break  # Stop if no more indices remain or the batch size is zero
        batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
        self._epoch_batches.append((seq_len, batch_indices))
        idx_pos += actual_bs
        scheduled_samples_count += actual_bs

    # Final check
    if scheduled_samples_count != effective_samples_this_rank:
        warnings.warn(
            f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
            f"but expected {effective_samples_this_rank} effective samples this epoch. "
            f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
        )
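
For orientation, a minimal sketch of how the prepared self._epoch_batches list could be consumed. The helper name iter_epoch_batches is hypothetical and not part of timm; it assumes base_dataset[i] returns a sample usable downstream, while the real class wires the prepared batches into its own iteration logic.

def iter_epoch_batches(naflex_ds, epoch: int):
    """Hypothetical helper (not timm API): walk one epoch of prepared batches."""
    naflex_ds._prepare_epoch_batches(epoch)
    for seq_len, batch_indices in naflex_ds._epoch_batches:
        # Each schedule entry pairs a target patch sequence length with the
        # dataset indices that should be collated into a batch of that shape.
        samples = [naflex_ds.base_dataset[i] for i in batch_indices]
        yield seq_len, samples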