in timm/data/naflex_loader.py [0:0]
def create_naflex_loader(
        dataset,
        patch_size: Optional[Union[Tuple[int, int], int]] = None,
        patch_size_choices: Optional[List[int]] = None,
        patch_size_choice_probs: Optional[List[float]] = None,
        train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
        max_seq_len: int = 576,
        batch_size: int = 32,
        is_training: bool = False,
        mixup_fn: Optional[Callable] = None,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
) -> Union[torch.utils.data.DataLoader, NaFlexPrefetchLoader]:
    """Create a NaFlex data loader for training or validation.

    In training mode the dataset is wrapped in ``NaFlexMapDatasetWrapper``, which
    receives the transform factory, the candidate sequence lengths and a per-batch
    token budget of ``batch_size * max(train_seq_lens)``. The wrapper yields
    already-collated batches, so the ``DataLoader`` is built with ``batch_size=None``.

    In validation mode a fixed ``max_seq_len`` transform is assigned to
    ``dataset.transform`` (the dataset object is mutated in place) and batches of
    ``batch_size`` samples are collated by ``NaFlexCollator``.

    Args:
        dataset: Dataset to load from. Must not be an ``IterableDataset`` when
            ``is_training`` is True.
        patch_size: Single patch size to use.
        patch_size_choices: List of patch sizes for variable patch size training
            (training only).
        patch_size_choice_probs: Probabilities for each patch size choice
            (training only).
        train_seq_lens: Training sequence lengths for dynamic batching. Must be
            non-empty when ``is_training`` is True.
        max_seq_len: Fixed sequence length for validation.
        batch_size: Batch size for validation; in training it is multiplied by
            ``max(train_seq_lens)`` to obtain the token budget per batch.
        is_training: Whether this is for training (enables dynamic batching).
        mixup_fn: Optional mixup function, forwarded to the training dataset wrapper.
        no_aug: Disable augmentation (training only).
        re_prob: Random erasing probability (training only).
        re_mode: Random erasing mode (training only).
        re_count: Maximum number of erasing rectangles (training only).
        re_split: Random erasing split flag. Accepted for signature compatibility;
            not read by this function.
        train_crop_mode: Training crop mode.
        scale: Scale range for random resize crop.
        ratio: Aspect ratio range for random resize crop.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Color jitter factor.
        color_jitter_prob: Color jitter probability.
        grayscale_prob: Grayscale conversion probability.
        gaussian_blur_prob: Gaussian blur probability.
        auto_augment: AutoAugment policy.
        num_aug_repeats: Number of augmentation repeats. Must be 0 in training.
        num_aug_splits: Number of augmentation splits. Accepted for signature
            compatibility; not read by this function.
        interpolation: Interpolation method.
        mean: Normalization mean values.
        std: Normalization standard deviation values.
        crop_pct: Crop percentage. Only forwarded to the training transform.
        crop_mode: Crop mode. Only forwarded to the training transform.
        crop_border_pixels: Crop border pixels. Only forwarded to the training transform.
        num_workers: Number of data loading workers.
        distributed: Whether using distributed training.
        rank: Process rank, forwarded to the training dataset wrapper.
        world_size: Total number of processes, forwarded to the training dataset wrapper.
        seed: Random seed, forwarded to the training dataset wrapper.
        epoch: Starting epoch, forwarded to the training dataset wrapper.
        use_prefetcher: Whether to wrap the loader in ``NaFlexPrefetchLoader``.
        pin_memory: Whether to pin memory.
        img_dtype: Image data type used by the prefetcher.
        device: Device used by the prefetcher.
        persistent_workers: Whether to use persistent workers (training loader only).
            Ignored when ``num_workers`` is 0, since PyTorch rejects that combination.
        worker_seeding: Worker seeding mode (training loader only).

    Returns:
        DataLoader or NaFlexPrefetchLoader instance.

    Raises:
        AssertionError: In training, if ``num_aug_repeats`` is non-zero or ``dataset``
            is an ``IterableDataset``.
        ValueError: In training, if ``train_seq_lens`` is empty.
    """
    if is_training:
        # For training, use the dynamic sequence length mechanism.
        # Explicit raises (rather than `assert`) so the checks survive `python -O`.
        if num_aug_repeats != 0:
            raise AssertionError('Augmentation repeats not currently supported in NaFlex loader')

        # The wrapper instantiates transforms itself (patch size / seq len vary per batch),
        # so hand it a factory with everything else pre-bound.
        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        if not train_seq_lens:
            # max() would raise ValueError here anyway; give the caller a usable message.
            raise ValueError('train_seq_lens must contain at least one sequence length')
        max_train_seq_len = max(train_seq_lens)
        # Token budget: `batch_size` samples at the longest sequence length.
        max_tokens_per_batch = batch_size * max_train_seq_len

        if isinstance(dataset, torch.utils.data.IterableDataset):
            raise AssertionError("IterableDataset Wrapper is a WIP")

        naflex_dataset = NaFlexMapDatasetWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            patch_size_choices=patch_size_choices,
            patch_size_choice_probs=patch_size_choice_probs,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            mixup_fn=mixup_fn,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training,
        # hence batch_size=None (automatic batching disabled).
        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
            persistent_workers=persistent_workers and num_workers > 0,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
                re_prob=re_prob,
                re_mode=re_mode,
                re_count=re_count,
            )
    else:
        # For validation, use fixed sequence length (unchanged).
        # NOTE: this assigns onto the caller's dataset object.
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator
        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

        # Handle distributed evaluation; iterable datasets cannot take a sampler.
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    return loader