def create_naflex_loader()

in timm/data/naflex_loader.py [0:0]


def create_naflex_loader(
        dataset,
        patch_size: Optional[Union[Tuple[int, int], int]] = None,
        patch_size_choices: Optional[List[int]] = None,
        patch_size_choice_probs: Optional[List[float]] = None,
        train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
        max_seq_len: int = 576,
        batch_size: int = 32,
        is_training: bool = False,
        mixup_fn: Optional[Callable] = None,

        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,

        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
    ) -> Union[torch.utils.data.DataLoader, NaFlexPrefetchLoader]:
    """Create a NaFlex data loader for training or validation.

    When ``is_training`` is True the dataset is wrapped in a
    ``NaFlexMapDatasetWrapper`` that yields already-collated batches (the
    ``DataLoader`` is created with ``batch_size=None``). Otherwise a fixed
    ``max_seq_len`` transform is assigned to ``dataset.transform`` (the dataset
    object is mutated in place) and batches are collated with ``NaFlexCollator``.

    Args:
        dataset: Dataset to load from. Iterable datasets are not supported for training.
        patch_size: Single patch size to use.
        patch_size_choices: List of patch sizes for variable patch size training (training only).
        patch_size_choice_probs: Probabilities for each patch size choice (training only).
        train_seq_lens: Training sequence lengths for dynamic batching. Must not be empty.
        max_seq_len: Fixed sequence length for validation.
        batch_size: Batch size for validation and max training sequence length.
        is_training: Whether this is for training (enables dynamic batching).
        mixup_fn: Optional mixup function (training only).
        no_aug: Disable augmentation. Also disables random erasing in the training prefetcher.
        re_prob: Random erasing probability.
        re_mode: Random erasing mode.
        re_count: Maximum number of erasing rectangles.
        re_split: Random erasing split flag. Currently unused by this function.
        train_crop_mode: Training crop mode.
        scale: Scale range for random resize crop.
        ratio: Aspect ratio range for random resize crop.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Color jitter factor.
        color_jitter_prob: Color jitter probability.
        grayscale_prob: Grayscale conversion probability.
        gaussian_blur_prob: Gaussian blur probability.
        auto_augment: AutoAugment policy.
        num_aug_repeats: Number of augmentation repeats. Must be 0 for training.
        num_aug_splits: Number of augmentation splits. Currently unused by this function.
        interpolation: Interpolation method.
        mean: Normalization mean values.
        std: Normalization standard deviation values.
        crop_pct: Crop percentage. Only forwarded to the training transform for now.
        crop_mode: Crop mode. Only forwarded to the training transform for now.
        crop_border_pixels: Crop border pixels. Only forwarded to the training transform for now.
        num_workers: Number of data loading workers.
        distributed: Whether using distributed training.
        rank: Process rank for distributed training (training only).
        world_size: Total number of processes (training only).
        seed: Random seed (training only).
        epoch: Starting epoch (training only).
        use_prefetcher: Whether to wrap the loader in a NaFlexPrefetchLoader.
        pin_memory: Whether to pin memory.
        img_dtype: Image data type used by the prefetcher.
        device: Device the prefetcher moves tensors to.
        persistent_workers: Whether to use persistent workers (training only). Ignored when
            ``num_workers`` is 0, since DataLoader rejects that combination.
        worker_seeding: Worker seeding mode (training only).

    Returns:
        DataLoader or NaFlexPrefetchLoader instance.

    Raises:
        AssertionError: If training is requested with ``num_aug_repeats != 0`` or with an
            ``IterableDataset``.
        ValueError: If training is requested with an empty ``train_seq_lens``.
    """

    if is_training:
        # For training, use the dynamic sequence length mechanism.
        # These guards are raised explicitly (rather than via `assert`) so they still fire
        # under `python -O`; AssertionError is kept for backward compatibility.
        if num_aug_repeats != 0:
            raise AssertionError('Augmentation repeats not currently supported in NaFlex loader')

        if isinstance(dataset, torch.utils.data.IterableDataset):
            raise AssertionError("IterableDataset Wrapper is a WIP")

        if len(train_seq_lens) == 0:
            # max() below would otherwise fail with an unhelpful 'empty sequence' message.
            raise ValueError('train_seq_lens must contain at least one sequence length')

        # The transform is built lazily by the dataset wrapper, so only bind the arguments here.
        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        # batch_size is defined at the longest training sequence length; the resulting
        # token budget is handed to the wrapper.
        max_train_seq_len = max(train_seq_lens)
        max_tokens_per_batch = batch_size * max_train_seq_len

        naflex_dataset = NaFlexMapDatasetWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            patch_size_choices=patch_size_choices,
            patch_size_choice_probs=patch_size_choice_probs,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            mixup_fn=mixup_fn,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training
        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0,
            # so the (default-on) flag is only honoured when worker processes exist.
            persistent_workers=persistent_workers and num_workers > 0,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
                # no_aug disables augmentation, so random erasing must be off here as well.
                re_prob=0. if no_aug else re_prob,
                re_mode=re_mode,
                re_count=re_count,
            )

    else:
        # For validation, use fixed sequence length (unchanged)
        # NOTE: this replaces any transform already set on the caller's dataset object.
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator
        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

        # Handle distributed evaluation (map-style datasets only)
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    return loader