"""Variable‑size Mixup / CutMix utilities for NaFlex data loaders.
This module provides:
* `mix_batch_variable_size` – pixel‑level Mixup/CutMix that operates on a
list of images whose spatial sizes differ, mixing only their central overlap
so no resizing is required.
* `pairwise_mixup_target` – builds soft‑label targets that exactly match the
per‑sample pixel provenance produced by the mixer.
* `NaFlexMixup` – a callable functor that wraps the two helpers and stores
all augmentation hyper‑parameters in one place, making it easy to plug into
different dataset wrappers.
Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
"""
import math
import random
from typing import Dict, List, Tuple

import torch


def mix_batch_variable_size(
imgs: List[torch.Tensor],
*,
mixup_alpha: float = 0.8,
cutmix_alpha: float = 1.0,
switch_prob: float = 0.5,
local_shuffle: int = 4,
) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]:
"""Apply Mixup or CutMix on a batch of variable-sized images.
Sorts images by aspect ratio and pairs neighboring samples. Only the mutual
central overlap region of each pair is mixed.
Args:
imgs: List of transformed images shaped (C, H, W).
mixup_alpha: Beta distribution alpha for Mixup. Set to 0 to disable.
cutmix_alpha: Beta distribution alpha for CutMix. Set to 0 to disable.
switch_prob: Probability of using CutMix when both modes are enabled.
local_shuffle: Size of local windows for shuffling after aspect sorting.
Returns:
Tuple of (mixed_imgs, lam_list, pair_to) where:
- mixed_imgs: List of mixed images
- lam_list: Per-sample lambda values representing mixing degree
- pair_to: Mapping i -> j of which sample was mixed with which
"""
if len(imgs) < 2:
raise ValueError("Need at least two images to perform Mixup/CutMix.")
# Decide augmentation mode and raw λ
if mixup_alpha > 0.0 and cutmix_alpha > 0.0:
use_cutmix = torch.rand(()).item() < switch_prob
alpha = cutmix_alpha if use_cutmix else mixup_alpha
elif mixup_alpha > 0.0:
use_cutmix = False
alpha = mixup_alpha
elif cutmix_alpha > 0.0:
use_cutmix = True
alpha = cutmix_alpha
else:
raise ValueError("Both mixup_alpha and cutmix_alpha are zero – nothing to do.")
lam_raw = torch.distributions.Beta(alpha, alpha).sample().item()
lam_raw = max(0.0, min(1.0, lam_raw)) # numerical safety
# Pair images by nearest aspect ratio
order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1])
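    # Shuffle within small windows so pairings still match similar aspect ratios
    # but vary from batch to batch rather than being fully deterministic.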
    if local_shuffle > 1:
        for start in range(0, len(order), local_shuffle):
            # Shuffling a slice directly would only shuffle a copy, so write the window back.
            window = order[start:start + local_shuffle]
            random.shuffle(window)
            order[start:start + local_shuffle] = window
pair_to: Dict[int, int] = {}
for a, b in zip(order[::2], order[1::2]):
pair_to[a] = b
pair_to[b] = a
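    # With an odd batch size the last image in the sorted order has no partner and passes through unmixed.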
odd_one = order[-1] if len(imgs) % 2 else None
mixed_imgs: List[torch.Tensor] = [None] * len(imgs)
lam_list: List[float] = [1.0] * len(imgs)
for i in range(len(imgs)):
if i == odd_one:
mixed_imgs[i] = imgs[i]
continue
j = pair_to[i]
xi, xj = imgs[i], imgs[j]
_, hi, wi = xi.shape
_, hj, wj = xj.shape
dest_area = hi * wi
# Central overlap common to both images
oh, ow = min(hi, hj), min(wi, wj)
overlap_area = oh * ow
top_i, left_i = (hi - oh) // 2, (wi - ow) // 2
top_j, left_j = (hj - oh) // 2, (wj - ow) // 2
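        # Mix into a copy so image j, which may itself be a mix destination, stays untouched.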
xi = xi.clone()
if use_cutmix:
# CutMix: random rectangle inside the overlap
cut_ratio = math.sqrt(1.0 - lam_raw)
ch, cw = int(oh * cut_ratio), int(ow * cut_ratio)
cut_area = ch * cw
y_off = random.randint(0, oh - ch)
x_off = random.randint(0, ow - cw)
yl_i, xl_i = top_i + y_off, left_i + x_off
yl_j, xl_j = top_j + y_off, left_j + x_off
xi[:, yl_i: yl_i + ch, xl_i: xl_i + cw] = xj[:, yl_j: yl_j + ch, xl_j: xl_j + cw]
mixed_imgs[i] = xi
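            # lam for sample i is the fraction of its own pixels kept: everything outside the
            # pasted rectangle, measured against the destination image area.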
corrected_lam = 1.0 - cut_area / float(dest_area)
lam_list[i] = corrected_lam
else:
# Mixup: blend the entire overlap region
patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]
blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
mixed_imgs[i] = xi
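            # Only the overlap was blended: pixels outside it are entirely sample i's own, pixels
            # inside contribute lam_raw of sample i, giving an area-weighted effective lambda.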
corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
lam_list[i] = corrected_lam
    return mixed_imgs, lam_list, pair_to


def smoothed_sparse_target(
targets: torch.Tensor,
*,
num_classes: int,
smoothing: float = 0.0,
) -> torch.Tensor:
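    """Expand integer class indices into (optionally label-smoothed) one-hot target rows.

    Args:
        targets: (B,) tensor of integer class indices.
        num_classes: Total number of classes in the dataset.
        smoothing: Label-smoothing value in the range [0, 1).

    Returns:
        Tensor of shape (B, num_classes) whose rows sum to 1.
    """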
off_val = smoothing / num_classes
on_val = 1.0 - smoothing + off_val
y_onehot = torch.full(
(targets.size(0), num_classes),
off_val,
dtype=torch.float32,
device=targets.device
)
y_onehot.scatter_(1, targets.unsqueeze(1), on_val)
    return y_onehot


def pairwise_mixup_target(
targets: torch.Tensor,
pair_to: Dict[int, int],
lam_list: List[float],
*,
num_classes: int,
smoothing: float = 0.0,
) -> torch.Tensor:
"""Create soft targets that match the pixel‑level mixing performed.
Args:
targets: (B,) tensor of integer class indices.
pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
lam_list: Per‑sample fractions of own pixels, also from the mixer.
num_classes: Total number of classes in the dataset.
smoothing: Label‑smoothing value in the range [0, 1).
Returns:
Tensor of shape (B, num_classes) whose rows sum to 1.
"""
y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing)
targets = y_onehot.clone()
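    # Blend into a copy so each partner row is read in its original, unmixed form.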
for i, j in pair_to.items():
lam = lam_list[i]
targets[i].mul_(lam).add_(y_onehot[j], alpha=1.0 - lam)
    return targets


class NaFlexMixup:
"""Callable wrapper that combines mixing and target generation."""
def __init__(
self,
*,
num_classes: int,
mixup_alpha: float = 0.8,
cutmix_alpha: float = 1.0,
switch_prob: float = 0.5,
prob: float = 1.0,
local_shuffle: int = 4,
label_smoothing: float = 0.0,
) -> None:
"""Configure the augmentation.
Args:
num_classes: Total number of classes.
mixup_alpha: Beta α for Mixup. 0 disables Mixup.
cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
switch_prob: Probability of selecting CutMix when both modes are enabled.
prob: Probability of applying any mixing per batch.
local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
smoothing: Label‑smoothing value. 0 disables smoothing.
"""
self.num_classes = num_classes
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.switch_prob = switch_prob
self.prob = prob
self.local_shuffle = local_shuffle
        self.smoothing = label_smoothing

def __call__(
self,
imgs: List[torch.Tensor],
targets: torch.Tensor,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Apply the augmentation and generate matching targets.
Args:
imgs: List of already transformed images shaped (C, H, W).
targets: Hard labels with shape (B,).
Returns:
mixed_imgs: List of mixed images in the same order and shapes as the input.
targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets.
"""
if not isinstance(targets, torch.Tensor):
targets = torch.tensor(targets)
if random.random() > self.prob:
targets = smoothed_sparse_target(targets, num_classes=self.num_classes, smoothing=self.smoothing)
return imgs, targets.unbind(0)
mixed_imgs, lam_list, pair_to = mix_batch_variable_size(
imgs,
mixup_alpha=self.mixup_alpha,
cutmix_alpha=self.cutmix_alpha,
switch_prob=self.switch_prob,
local_shuffle=self.local_shuffle,
)
targets = pairwise_mixup_target(
targets,
pair_to,
lam_list,
num_classes=self.num_classes,
smoothing=self.smoothing,
)
return mixed_imgs, targets.unbind(0)
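

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: exercises NaFlexMixup on a toy
    # batch of variable-sized images. The shapes, class count, and labels below are
    # arbitrary placeholders chosen only for illustration.
    mixer = NaFlexMixup(num_classes=10, mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1)
    imgs = [
        torch.randn(3, 224, 160),
        torch.randn(3, 192, 256),
        torch.randn(3, 256, 224),
        torch.randn(3, 160, 160),
    ]
    targets = torch.tensor([0, 3, 5, 7])
    mixed_imgs, soft_targets = mixer(imgs, targets)
    for img, tgt in zip(mixed_imgs, soft_targets):
        # Shapes are preserved and each soft-label row still sums to ~1.
        print(tuple(img.shape), round(float(tgt.sum()), 4))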