def _create_canonical_schedule()

in timm/data/naflex_dataset.py [0:0]


    def _create_canonical_schedule(self):
        """
        Calculates the canonical batch schedule (seq_len, batch_size pairs)
        based on the dataset size, padded for distributed training.

        This schedule is the *same* for all ranks and ensures consistent
        epoch length. It is calculated once during initialization.

        Sets:
            self._padded_samples_per_rank: Samples each rank processes per epoch.
                When distributed with world_size > 1 this is the dataset length
                padded up to a multiple of world_size, divided by world_size;
                otherwise it is the dataset length.
            self._canonical_batch_schedule: List of ``(seq_len, batch_size)``
                tuples whose batch sizes sum exactly to
                ``self._padded_samples_per_rank`` (empty when that is 0).
            self._num_batches_per_rank: Number of entries in the schedule.
        """
        total_len = len(self.base_dataset)
        num_samples_per_rank = total_len

        if self.distributed and self.world_size > 1:
            # Smallest padding that makes the total evenly divisible by
            # world_size (0 when it already is), so all ranks get equal shares.
            pad_size = -total_len % self.world_size
            padded_total_len = total_len + pad_size
            if pad_size:
                print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
            num_samples_per_rank = padded_total_len // self.world_size
        # Otherwise (not distributed, or world_size <= 1) this rank sees every sample.

        self._padded_samples_per_rank = num_samples_per_rank

        if num_samples_per_rank == 0:
            self._canonical_batch_schedule = []
            self._num_batches_per_rank = 0
            return

        # Fixed generator seeded with the base seed (NOT the epoch seed) so the
        # schedule structure does not change between epochs.
        g = torch.Generator()
        g.manual_seed(self.seed)

        schedule: List[Tuple[int, int]] = []
        remaining_samples = num_samples_per_rank

        while remaining_samples > 0:
            # Sample sequence length deterministically based on base seed
            seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
            seq_len = self.seq_lens[seq_idx]

            batch_size = calculate_naflex_batch_size(
                tokens_per_batch=self.max_tokens_per_batch,
                seq_len=seq_len,
                # cap at remaining_samples to avoid overshooting the per-rank count
                max_size=remaining_samples,
                divisor=self.batch_divisor,
                rounding='floor',
            )
            # Clamp to [1, remaining_samples]. This guarantees progress on every
            # iteration and that the loop ends with remaining_samples == 0
            # exactly, so the schedule always covers every per-rank sample.
            batch_size = min(max(1, batch_size), remaining_samples)

            schedule.append((seq_len, batch_size))
            remaining_samples -= batch_size

        self._canonical_batch_schedule = schedule
        self._num_batches_per_rank = len(schedule)
        print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")