def _prepare_epoch_batches()

in timm/data/naflex_dataset.py [0:0]


    def _prepare_epoch_batches(self, epoch: int):
        """
        Prepares the batches for the current epoch by:
        1. Shuffling the full dataset indices (using epoch seed).
        2. Applying padding if in distributed mode.
        3. Selecting indices for the current rank.
        4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
        5. Assigning the rank's indices to the shuffled batches.
        """
        g = torch.Generator()
        g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling

        # 1. Get shuffled global indices
        total_len = len(self.base_dataset)
        if self.shuffle:
            all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
        else:
            all_indices_shuffled = list(range(total_len))

        # 2. Apply padding for distributed mode
        indices_for_ranks = all_indices_shuffled
        if self.distributed and self.world_size > 1:
            padded_total_len = self._padded_samples_per_rank * self.world_size
            if padded_total_len > total_len:
                pad_size = padded_total_len - total_len
                # Repeat initial elements from the *shuffled* list for padding
                indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
            # Ensure length matches expectation
            if len(indices_for_ranks) != padded_total_len:
                 raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")

        # 3. Select indices for the current rank
        if self.distributed and self.world_size > 1:
            indices_this_rank = indices_for_ranks[self.rank::self.world_size]
        else: # Non-distributed or world_size=1
            indices_this_rank = indices_for_ranks

        # Sanity check length
        if len(indices_this_rank) != self._padded_samples_per_rank:
             # This might happen if canonical schedule generation had warnings/issues
             warnings.warn(
                 f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
                 f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
                 f"Epoch generation might be inconsistent."
              )
             # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn.
             # Using min() prevents IndexError later if indices are fewer than expected.
             effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
             indices_this_rank = indices_this_rank[:effective_samples_this_rank]

        else:
             effective_samples_this_rank = self._padded_samples_per_rank

        # 4. Shuffle the order of the canonical batch schedule for this epoch
        if self.shuffle:
            schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
            shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
        else:
            shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order

        # 5. Assign indices to the shuffled batches
        self._epoch_batches = []
        idx_pos = 0
        scheduled_samples_count = 0
        for seq_len, bs in shuffled_schedule:
            # Ensure we don't try to grab more indices than available for the rank
            actual_bs = min(bs, effective_samples_this_rank - idx_pos)
            if actual_bs <= 0:
                 if scheduled_samples_count < effective_samples_this_rank:
                     # This indicates mismatch between schedule total and actual samples
                     warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
                 break # Stop if no more indices or batch size is zero

            batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
            self._epoch_batches.append((seq_len, batch_indices))
            idx_pos += actual_bs
            scheduled_samples_count += actual_bs

        # Final check
        if scheduled_samples_count != effective_samples_this_rank:
             warnings.warn(
                f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
                f"but expected {effective_samples_this_rank} effective samples this epoch. "
                f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
             )