def __iter__()

in timm/data/naflex_dataset.py [0:0]


    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        """Iterate over this worker's share of the pre-calculated epoch batches.

        Each entry of ``self._epoch_batches`` is a ``(seq_len, indices)`` pair.
        Entries are divided among DataLoader workers with a strided slice
        (worker ``i`` of ``n`` takes entries ``i, i + n, i + 2n, ...``). For each
        entry the samples are read from ``self.base_dataset``, passed through
        the transform registered for ``(seq_len, patch_idx)`` (if any),
        optionally mixed by ``self.mixup_fn``, patchified, and finally collated
        with ``self.collate_fns[seq_len]``.

        Samples that fail to load or transform are skipped with a warning. A
        batch in which no sample survives is dropped before the mixup,
        patchify and collate steps run, so none of them see an empty batch.

        Yields:
            Tuple of (input_dict, targets) for each non-empty batch, as
            produced by the collate function for that sequence length.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Not running inside a DataLoader worker: handle every batch here.
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        # The sampling weights are loop-invariant, so build the tensor once
        # instead of once per batch. Only needed when patch size is variable.
        patch_size_probs = (
            torch.tensor(self.patch_size_probs) if self.variable_patch_size else None
        )

        # Each worker processes a strided slice of the prepared batches.
        for seq_len, indices in self._epoch_batches[worker_id::num_workers]:
            if not indices:
                # A batch with no indices has nothing to load.
                continue

            # Select the patch size for this batch (index 0 unless sampling).
            patch_idx = 0
            if patch_size_probs is not None:
                # Weighted random choice over the configured patch sizes.
                patch_idx = torch.multinomial(patch_size_probs, 1).item()

            # Look up the pre-initialized transform and patchifier.
            transform = self.transforms.get((seq_len, patch_idx))
            batch_patchifier = self.patchifiers[patch_idx]

            batch_imgs = []
            batch_targets = []
            for idx in indices:
                try:
                    # Get original image and label from the map-style dataset,
                    # then apply the transform when one is registered.
                    img, label = self.base_dataset[idx]
                    processed_img = transform(img) if transform else img
                except IndexError:
                    warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
                    continue
                except Exception as e:
                    # Best-effort loading: report the failure and move on so a
                    # single bad sample does not abort the whole epoch.
                    warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                    continue
                else:
                    if processed_img is None:
                        warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                        continue
                    batch_imgs.append(processed_img)
                    batch_targets.append(label)

            if not batch_imgs:
                # Every sample was skipped; do not hand an empty batch to
                # mixup, the patchifier or the collate function.
                continue

            if self.mixup_fn is not None:
                batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)

            batch_imgs = [batch_patchifier(img) for img in batch_imgs]
            batch_samples = list(zip(batch_imgs, batch_targets))
            if batch_samples:  # mixup could in principle return nothing
                # Collate the processed samples into a batch
                yield self.collate_fns[seq_len](batch_samples)