# timm/data/naflex_loader.py
"""NaFlex data loader for dynamic sequence length training.
This module provides a specialized data loader for Vision Transformer models that supports:
- Dynamic sequence length sampling during training for improved efficiency
- Variable patch size training with probabilistic selection
- Patch-level random erasing augmentation
- Efficient GPU prefetching with normalization
Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
"""
from contextlib import suppress
from functools import partial
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .loader import _worker_init, adapt_to_chs
from .naflex_dataset import NaFlexMapDatasetWrapper, NaFlexCollator
from .naflex_random_erasing import PatchRandomErasing
from .transforms_factory import create_transform
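
# Example usage (a minimal sketch; `train_dataset` stands in for any map-style
# dataset yielding (image, label) samples, and a CUDA device is assumed):
#
#   loader = create_naflex_loader(
#       train_dataset,
#       patch_size=16,
#       train_seq_lens=(128, 256, 576, 784, 1024),
#       batch_size=32,  # batch size at the longest sequence length
#       is_training=True,
#   )
#   for input_dict, targets in loader:
#       patches = input_dict['patches']      # normalized on device by the prefetcher
#       coords = input_dict['patch_coord']   # per-patch grid coordinates
#       valid = input_dict['patch_valid']    # mask of real (non-padding) patches
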
class NaFlexPrefetchLoader:
"""Data prefetcher for NaFlex format which normalizes patches."""
def __init__(
self,
loader: torch.utils.data.DataLoader,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
channels: int = 3,
device: torch.device = torch.device('cuda'),
img_dtype: Optional[torch.dtype] = None,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_num_splits: int = 0,
) -> None:
"""Initialize NaFlexPrefetchLoader.
Args:
loader: DataLoader to prefetch from.
mean: Mean values for normalization.
std: Standard deviation values for normalization.
channels: Number of image channels.
device: Device to move tensors to.
img_dtype: Data type for image tensors.
re_prob: Random erasing probability.
re_mode: Random erasing mode.
re_count: Maximum number of erasing rectangles.
re_num_splits: Number of augmentation splits.
"""
self.loader = loader
self.device = device
self.img_dtype = img_dtype or torch.float32
        # Create mean/std tensors for normalization (applied per patch). Scaling by
        # 255 folds the uint8 -> float conversion into normalization, i.e. patches
        # with values in [0, 255] come out as (x - 255 * mean) / (255 * std).
mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
normalization_shape = (1, 1, channels)
self.channels = channels
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape)
self.std = torch.tensor(
[x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape)
if re_prob > 0.:
self.random_erasing = PatchRandomErasing(
erase_prob=re_prob,
mode=re_mode,
max_count=re_count,
num_splits=re_num_splits,
device=device,
)
else:
self.random_erasing = None
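        # Note: random erasing here operates on the patch sequence directly,
        # consuming patch coordinates and a validity mask at call time (see
        # `__iter__`) rather than on full images.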
# Check for CUDA/NPU availability
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
self.is_npu = device.type == 'npu' and torch.npu.is_available()
def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
"""Iterate through the loader with prefetching and normalization.
Yields:
Tuple of (input_dict, targets) with normalized patches.
"""
first = True
if self.is_cuda:
stream = torch.cuda.Stream(device=self.device)
stream_context = partial(torch.cuda.stream, stream=stream)
elif self.is_npu:
stream = torch.npu.Stream(device=self.device)
stream_context = partial(torch.npu.stream, stream=stream)
else:
stream = None
stream_context = suppress
for next_input_dict, next_target in self.loader:
with stream_context():
# Move all tensors in input_dict to device
for k, v in next_input_dict.items():
if isinstance(v, torch.Tensor):
dtype = self.img_dtype if k == 'patches' else None
next_input_dict[k] = next_input_dict[k].to(
device=self.device,
non_blocking=True,
dtype=dtype,
)
next_target = next_target.to(device=self.device, non_blocking=True)
# Normalize patch values - handle both [B, N, P*P*C] and [B, N, Ph, Pw, C] formats
patches_tensor = next_input_dict['patches']
original_shape = patches_tensor.shape
if patches_tensor.ndim == 3:
# Format: [B, N, P*P*C] - flattened patches
batch_size, num_patches, patch_pixels = original_shape
                    # To [B, N, P*P, C] for normalization and erasing
patches = patches_tensor.view(batch_size, num_patches, -1, self.channels)
elif patches_tensor.ndim == 5:
# Format: [B, N, Ph, Pw, C] - unflattened patches (variable patch size mode)
batch_size, num_patches, patch_h, patch_w, channels = original_shape
assert channels == self.channels, f"Expected {self.channels} channels, got {channels}"
                    # To [B, N, Ph*Pw, C] for normalization and erasing
patches = patches_tensor.view(batch_size, num_patches, -1, self.channels)
else:
raise ValueError(f"Unexpected patches tensor dimensions: {patches_tensor.ndim}. Expected 3 or 5.")
# Apply normalization
patches = patches.sub(self.mean).div(self.std)
if self.random_erasing is not None:
patches = self.random_erasing(
patches,
patch_coord=next_input_dict['patch_coord'],
patch_valid=next_input_dict.get('patch_valid', None),
)
# Reshape back to original format
next_input_dict['patches'] = patches.view(original_shape)
if not first:
yield input_dict, target
else:
first = False
if stream is not None:
if self.is_cuda:
torch.cuda.current_stream(device=self.device).wait_stream(stream)
elif self.is_npu:
torch.npu.current_stream(device=self.device).wait_stream(stream)
input_dict = next_input_dict
target = next_target
yield input_dict, target
def __len__(self) -> int:
"""Get length of underlying loader.
Returns:
Number of batches in the loader.
"""
return len(self.loader)
@property
def sampler(self):
"""Get sampler from underlying loader.
Returns:
Sampler from the underlying DataLoader.
"""
return self.loader.sampler
@property
def dataset(self):
"""Get dataset from underlying loader.
Returns:
Dataset from the underlying DataLoader.
"""
return self.loader.dataset


def create_naflex_loader(
dataset,
patch_size: Optional[Union[Tuple[int, int], int]] = None,
patch_size_choices: Optional[List[int]] = None,
patch_size_choice_probs: Optional[List[float]] = None,
train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
max_seq_len: int = 576,
batch_size: int = 32,
is_training: bool = False,
mixup_fn: Optional[Callable] = None,
no_aug: bool = False,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_split: bool = False,
train_crop_mode: Optional[str] = None,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
hflip: float = 0.5,
vflip: float = 0.,
color_jitter: float = 0.4,
color_jitter_prob: Optional[float] = None,
grayscale_prob: float = 0.,
gaussian_blur_prob: float = 0.,
auto_augment: Optional[str] = None,
num_aug_repeats: int = 0,
num_aug_splits: int = 0,
interpolation: str = 'bilinear',
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
crop_pct: Optional[float] = None,
crop_mode: Optional[str] = None,
crop_border_pixels: Optional[int] = None,
num_workers: int = 4,
distributed: bool = False,
rank: int = 0,
world_size: int = 1,
seed: int = 42,
epoch: int = 0,
use_prefetcher: bool = True,
pin_memory: bool = True,
img_dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = torch.device('cuda'),
persistent_workers: bool = True,
worker_seeding: str = 'all',
) -> Union[torch.utils.data.DataLoader, NaFlexPrefetchLoader]:
"""Create a data loader with dynamic sequence length sampling for training.
Args:
dataset: Dataset to load from.
patch_size: Single patch size to use.
patch_size_choices: List of patch sizes for variable patch size training.
patch_size_choice_probs: Probabilities for each patch size choice.
train_seq_lens: Training sequence lengths for dynamic batching.
max_seq_len: Fixed sequence length for validation.
        batch_size: Batch size at the maximum training sequence length; shorter
            sequence lengths pack proportionally more samples under the same
            token budget. Also the fixed validation batch size.
is_training: Whether this is for training (enables dynamic batching).
mixup_fn: Optional mixup function.
no_aug: Disable augmentation.
re_prob: Random erasing probability.
re_mode: Random erasing mode.
re_count: Maximum number of erasing rectangles.
re_split: Random erasing split flag.
train_crop_mode: Training crop mode.
scale: Scale range for random resize crop.
ratio: Aspect ratio range for random resize crop.
hflip: Horizontal flip probability.
vflip: Vertical flip probability.
color_jitter: Color jitter factor.
color_jitter_prob: Color jitter probability.
grayscale_prob: Grayscale conversion probability.
gaussian_blur_prob: Gaussian blur probability.
auto_augment: AutoAugment policy.
num_aug_repeats: Number of augmentation repeats.
num_aug_splits: Number of augmentation splits.
interpolation: Interpolation method.
mean: Normalization mean values.
std: Normalization standard deviation values.
crop_pct: Crop percentage for validation.
crop_mode: Crop mode.
crop_border_pixels: Crop border pixels.
num_workers: Number of data loading workers.
distributed: Whether using distributed training.
rank: Process rank for distributed training.
world_size: Total number of processes.
seed: Random seed.
epoch: Starting epoch.
use_prefetcher: Whether to use prefetching.
pin_memory: Whether to pin memory.
img_dtype: Image data type.
device: Device to move tensors to.
persistent_workers: Whether to use persistent workers.
worker_seeding: Worker seeding mode.
Returns:
DataLoader or NaFlexPrefetchLoader instance.
"""
    # Normalize device so downstream `.type` checks also work for str inputs
    if isinstance(device, str):
        device = torch.device(device)

    if is_training:
# For training, use the dynamic sequence length mechanism
assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'
transform_factory = partial(
create_transform,
is_training=True,
no_aug=no_aug,
train_crop_mode=train_crop_mode,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
color_jitter_prob=color_jitter_prob,
grayscale_prob=grayscale_prob,
gaussian_blur_prob=gaussian_blur_prob,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,
crop_pct=crop_pct,
crop_mode=crop_mode,
crop_border_pixels=crop_border_pixels,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
use_prefetcher=use_prefetcher,
naflex=True,
)
max_train_seq_len = max(train_seq_lens)
max_tokens_per_batch = batch_size * max_train_seq_len
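        # Fixed token budget per batch: e.g. batch_size=32 with a max train seq
        # len of 1024 gives 32768 tokens, so at seq_len=256 the wrapper can pack
        # roughly 128 samples per batch under the same budget.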
        if isinstance(dataset, torch.utils.data.IterableDataset):
            # NOTE: `raise` instead of `assert` so the guard survives `python -O`
            raise NotImplementedError('IterableDataset wrapper support is a WIP')
naflex_dataset = NaFlexMapDatasetWrapper(
dataset,
transform_factory=transform_factory,
patch_size=patch_size,
patch_size_choices=patch_size_choices,
patch_size_choice_probs=patch_size_choice_probs,
seq_lens=train_seq_lens,
max_tokens_per_batch=max_tokens_per_batch,
mixup_fn=mixup_fn,
seed=seed,
distributed=distributed,
rank=rank,
world_size=world_size,
shuffle=True,
epoch=epoch,
)
        # NOTE: the wrapper yields pre-collated batches for training, so the
        # DataLoader below is created with batch_size=None and no collate_fn
loader = torch.utils.data.DataLoader(
naflex_dataset,
batch_size=None,
shuffle=False,
num_workers=num_workers,
sampler=None,
pin_memory=pin_memory,
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
persistent_workers=persistent_workers
)
if use_prefetcher:
loader = NaFlexPrefetchLoader(
loader,
mean=mean,
std=std,
img_dtype=img_dtype,
device=device,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
)
else:
        # For validation, use a fixed sequence length
dataset.transform = create_transform(
is_training=False,
interpolation=interpolation,
mean=mean,
std=std,
# FIXME add crop args when sequence transforms support crop modes
use_prefetcher=use_prefetcher,
naflex=True,
patch_size=patch_size,
max_seq_len=max_seq_len,
patchify=True,
)
        # Create the collator; it pads each batch to the fixed max_seq_len so
        # validation batches have a uniform shape
        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)
# Handle distributed training
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
# For validation, use OrderedDistributedSampler
from timm.data.distributed_sampler import OrderedDistributedSampler
sampler = OrderedDistributedSampler(dataset)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=False,
)
if use_prefetcher:
loader = NaFlexPrefetchLoader(
loader,
mean=mean,
std=std,
img_dtype=img_dtype,
device=device,
)
return loader
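
# Validation sketch (hypothetical `val_dataset`; fixed sequence length path):
#
#   val_loader = create_naflex_loader(
#       val_dataset,
#       patch_size=16,
#       max_seq_len=576,
#       batch_size=64,
#       is_training=False,
#   )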