timm/models/naflexvit.py (1,076 lines of code) (raw):
""" NaFlex Vision Transformer
An improved version of the Vision Transformer with:
1. Encapsulated embedding and position encoding in a single module
2. Support for linear patch embedding on pre-patchified inputs
3. Support for NaFlex variable aspect, variable resolution
4. Support for FlexiViT variable patch size
5. Support for NaViT fractional/factorized position embedding
Based on ideas from:
- Original Vision Transformer: https://arxiv.org/abs/2010.11929
- FlexiViT: https://arxiv.org/abs/2212.08013
- NaViT: https://arxiv.org/abs/2307.06304
- NaFlex (SigLip-2): https://arxiv.org/abs/2502.14786
Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
"""
import logging
import math
from dataclasses import dataclass, fields, replace
from functools import partial
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import (
AttentionPoolLatent,
Mlp,
to_2tuple,
get_act_layer,
get_norm_layer,
LayerNorm,
_assert,
)
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
from timm.models._features_fx import register_notrace_function, register_notrace_module
from timm.models._registry import register_model, generate_default_cfgs
from timm.models._manipulate import checkpoint, checkpoint_seq, named_apply
from .vision_transformer import Block, global_pool_nlc
# Public names exported by `from timm.models.naflexvit import *`.
__all__ = ['NaFlexVitCfg', 'NaFlexVit']
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
@dataclass
class NaFlexVitCfg:
    """Configuration for NaFlexVit model.

    This dataclass contains the bulk of model configuration parameters,
    with core parameters (img_size, in_chans, num_classes, etc.) remaining
    as direct constructor arguments for API compatibility.

    NOTE: field order is part of the positional-construction API; append new
    fields rather than inserting them.
    """
    # Architecture parameters
    patch_size: Union[int, Tuple[int, int]] = 16
    embed_dim: int = 768
    depth: int = 12
    num_heads: int = 12
    mlp_ratio: float = 4.0

    # Attention parameters
    qkv_bias: bool = True
    qk_norm: bool = False
    proj_bias: bool = True
    attn_drop_rate: float = 0.0

    # Regularization
    init_values: Optional[float] = None  # Layer-scale init values (layer-scale enabled if not None)
    drop_rate: float = 0.0  # Dropout rate for classifier
    pos_drop_rate: float = 0.0  # Dropout rate for position embeddings
    patch_drop_rate: float = 0.0  # Dropout rate for patch tokens
    proj_drop_rate: float = 0.0  # Dropout rate for linear projections
    drop_path_rate: float = 0.0  # Stochastic depth drop rate

    # Prefix token configuration
    class_token: bool = False  # Use class token
    reg_tokens: int = 0  # Number of register tokens

    # Position embedding configuration
    pos_embed: str = 'learned'  # Type of position embedding ('learned', 'factorized', 'rope', 'none')
    pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16)  # Grid size for position embedding initialization
    pos_embed_interp_mode: str = 'bicubic'  # Interpolation mode for position embedding resizing
    pos_embed_ar_preserving: bool = False  # Whether to preserve aspect ratio during position embedding interpolation
    pos_embed_use_grid_sample: bool = False  # Whether to use grid_sample for naflex position embedding interpolation

    # Image processing
    dynamic_img_pad: bool = False  # Whether to enable dynamic padding for variable resolution

    # Architecture choices
    pre_norm: bool = False  # Whether to apply normalization before attention/MLP layers (start of blocks)
    final_norm: bool = True  # Whether to apply final normalization before pooling and classifier (end of blocks)
    fc_norm: Optional[bool] = None  # Whether to normalize features before final classifier (after pooling)
    global_pool: str = 'map'  # Type of global pooling for final sequence
    pool_include_prefix: bool = False  # Whether to include class/register prefix tokens in global pooling

    # Weight initialization
    weight_init: str = ''  # Weight initialization scheme
    fix_init: bool = True  # Apply weight initialization fix (scaling w/ layer index)

    # Embedding configuration
    embed_proj_type: str = 'linear'  # Type of embedding layer ('conv' or 'linear')
    input_norm_layer: Optional[str] = None  # Normalization layer for embeddings input (before input projection)
    embed_norm_layer: Optional[str] = None  # Normalization layer for embeddings (after input projection)

    # Layer implementations
    norm_layer: Optional[str] = None  # Normalization layer for transformer blocks
    act_layer: Optional[str] = None  # Activation layer for MLP blocks
    block_fn: Optional[str] = None  # Transformer block implementation class name
    mlp_layer: Optional[str] = None  # MLP implementation class name

    # Variable patch size support
    enable_patch_interpolator: bool = False  # Enable dynamic patch size support
def _overlay_kwargs(cfg: NaFlexVitCfg, **kwargs) -> NaFlexVitCfg:
    """Return ``cfg`` with any matching keyword arguments applied on top.

    Keyword arguments whose names are not dataclass fields of ``cfg`` are
    silently ignored. If no keyword matches a field, the very same ``cfg``
    object is returned (no copy); otherwise a new instance is produced with
    ``dataclasses.replace``.
    """
    known_fields = cfg.__dataclass_fields__
    overrides = {}
    for name, value in kwargs.items():
        if name in known_fields:
            overrides[name] = value
    if not overrides:
        return cfg
    return replace(cfg, **overrides)
def batch_patchify(
        x: torch.Tensor,
        patch_size: Tuple[int, int],
        pad: bool = True,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Patchify a batch of images.

    Splits every image into non-overlapping ``ph x pw`` patches and flattens each
    patch in (ph, pw, C) order, i.e. channels-last within a patch.

    Args:
        x: Input tensor of shape [B, C, H, W].
        patch_size: Patch dimensions (patch_h, patch_w).
        pad: If True, zero-pad the bottom/right edges so that H and W become
            divisible by the patch size. If False, trailing rows/columns that do
            not fill a whole patch are dropped (the same truncation a Conv2d with
            stride == kernel_size applies).

    Returns:
        Tuple of (patches, grid_size) where patches has shape [B, nh * nw, ph * pw * C]
        and grid_size is (nh, nw), the number of patches along height and width.
    """
    B, C, H, W = x.shape
    ph, pw = patch_size

    if pad:
        # Round H/W up to the next multiple of the patch size (both are 0 when already divisible)
        pad_h = (ph - H % ph) % ph
        pad_w = (pw - W % pw) % pw
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
            # The grid must be derived from the padded size, otherwise the reshape below
            # sees more elements than B * C * nh * ph * nw * pw.
            H, W = H + pad_h, W + pad_w

    nh, nw = H // ph, W // pw
    if nh * ph != H or nw * pw != W:
        # Only reachable with pad=False: drop the remainder that cannot form a full patch.
        x = x[:, :, :nh * ph, :nw * pw]

    # reshape (not view) so non-contiguous inputs, such as the cropped tensor above, are handled
    patches = x.reshape(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. (ph, pw, C) instead of (C, ph, pw)
    return patches, (nh, nw)
def calculate_naflex_grid_sizes(_coord: torch.Tensor) -> List[Tuple[int, int]]:
    """Compute each sample's (height, width) patch grid size from its patch coordinates.

    Args:
        _coord: Patch coordinates of shape [B, N, 2] holding (y, x) indices per patch.

    Returns:
        List of B ``(h, w)`` tuples of Python ints, where h = max(y) + 1 and w = max(x) + 1
        over the sequence dimension of each sample.
    """
    # Per-sample extent of each axis: (B, 2) laid out as [h, w]
    sizes = _coord[:, :, :2].amax(dim=1) + 1
    # A single device->host transfer for the whole batch instead of two `.item()` syncs per sample
    return [(int(h), int(w)) for h, w in sizes.tolist()]
@register_notrace_module
class NaFlexEmbeds(nn.Module):
    """NaFlex Embedding module for Vision Transformers.

    This module encapsulates the complete embedding process for Vision Transformers,
    supporting both standard and NaFlex (NaViT + FlexiViT) functionality:

    1. Patch embedding (via Conv2d or Linear)
    2. Optional normalization of the projected patch embeddings
    3. Position embedding addition with interpolation support
    4. Class and register token prepending
    5. Dropout application (position dropout, then patch dropout)

    NaFlex capabilities include:
    - Variable aspect ratio and resolution via patch coordinates
    - Flexible position embedding interpolation for arbitrary grid sizes
    - Support for factorized position embeddings
    - Optional variable patch size via a patch embedding interpolator

    The patch embedding can be one of two types:
    - Conv2d-based (default): For standard image inputs [B, C, H, W]
    - Linear-based: For image inputs [B, C, H, W] (patchified internally) or, when
      patch_coord is given, pre-patchified inputs [B, N, P*P*C] / [B, N, Ph, Pw, C]

    Args:
        patch_size: Size of patches for patch embedding
        in_chans: Number of input image channels
        embed_dim: Dimensionality of patch embedding
        proj_type: Type of embedding projection layer ('linear'; anything else selects conv)
        proj_bias: Whether to use bias in the projection layer
        class_token: Whether to include a class token
        reg_tokens: Number of register tokens to include
        dynamic_img_pad: Whether to enable dynamic padding for variable resolution
        default_img_size: Default image size for position embedding grid calculation
        pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none')
        pos_embed_grid_size: Grid size for position embedding initialization
        pos_embed_interp_mode: Interpolation mode for position embedding resizing
        pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
        pos_embed_use_grid_sample: Whether to use grid_sample for NaFlex position embedding interpolation
        input_norm_layer: Normalization layer applied to input (linear mode only)
        proj_norm_layer: Normalization layer applied after projection
        norm_layer: Default normalization layer used when a norm argument is True
        pos_drop_rate: Dropout rate for position embeddings
        patch_drop_rate: Dropout rate for patch tokens
        enable_patch_interpolator: Enable dynamic patch size support
    """

    def __init__(
            self,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            proj_type: Optional[str] = None,
            proj_bias: bool = True,
            class_token: bool = True,
            reg_tokens: int = 0,
            dynamic_img_pad: bool = False,
            default_img_size: Optional[Union[int, Tuple[int, int]]] = None,
            pos_embed: str = 'learned',
            pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
            pos_embed_interp_mode: str = 'bicubic',
            pos_embed_ar_preserving: bool = False,
            pos_embed_use_grid_sample: bool = False,
            input_norm_layer: Optional[Type[nn.Module]] = None,
            proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            enable_patch_interpolator: bool = False,
    ) -> None:
        """Initialize NaFlexEmbeds module.

        Args:
            patch_size: Size of patches for patch embedding.
            in_chans: Number of input image channels.
            embed_dim: Dimensionality of patch embedding.
            proj_type: Type of embedding projection layer ('linear'; anything else selects conv).
            proj_bias: Whether to use bias in projection layers.
            class_token: Whether to include a class token.
            reg_tokens: Number of register tokens to include.
            dynamic_img_pad: Whether to enable dynamic padding for variable resolution.
            default_img_size: Default image size for position embedding grid calculation.
                Only used (and only stored) when pos_embed_grid_size is None.
            pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none').
                Unrecognized values fall through to 'learned'.
            pos_embed_grid_size: Grid size for position embedding initialization.
            pos_embed_interp_mode: Interpolation mode for position embedding resizing.
            pos_embed_ar_preserving: Whether to preserve aspect ratio during interpolation.
            pos_embed_use_grid_sample: Whether to use grid_sample for NaFlex position embedding interpolation.
            input_norm_layer: Normalization layer applied to input (linear mode only).
                ``True`` selects ``norm_layer``.
            proj_norm_layer: Normalization layer applied after projection. ``True`` selects ``norm_layer``.
            norm_layer: Default normalization layer.
            pos_drop_rate: Dropout rate for position embeddings.
            patch_drop_rate: Dropout rate for patch tokens.
            enable_patch_interpolator: Enable dynamic patch size support.

        Raises:
            ValueError: If a learned or factorized position embedding is requested but no grid
                size can be determined from pos_embed_grid_size or default_img_size.
        """
        super().__init__()
        self.has_class_token = class_token
        self.num_reg_tokens = reg_tokens
        self.pos_embed_interp_mode = pos_embed_interp_mode
        self.pos_embed_ar_preserving = pos_embed_ar_preserving
        self.pos_embed_use_grid_sample = pos_embed_use_grid_sample
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.dynamic_img_pad = dynamic_img_pad
        self.enable_patch_interpolator = enable_patch_interpolator

        # Calculate number of prefix tokens
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens

        # Create class and register tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None

        # Calculate grid size and number of patches
        self.default_img_size: Optional[Tuple[int, int]] = None
        self.pos_embed_grid_size: Optional[Tuple[int, int]] = None  # Grid size used for learned pos embed init
        if pos_embed_grid_size is not None:
            # Highest priority, use provided pos_embed_grid_size
            self.pos_embed_grid_size = pos_embed_grid_size
        elif default_img_size is not None:
            # Fallback to calculating grid size from img_size + patch_size if img size provided.
            self.default_img_size = to_2tuple(default_img_size)
            self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)])

        # Determine patch embedding type (linear or conv2d)
        if proj_type == 'linear':
            # Create linear projection for patchified inputs
            # Input dimension is patch_size^2 * in_chans
            patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans
            assert not (input_norm_layer is True and norm_layer is None), \
                "`norm_layer` must be given when input_norm_layer=True"
            input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None)
            self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None
            self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias)
            self.flatten = False
            self.is_linear = True
        else:
            # Default to convolutional patch embedding for image inputs
            assert not input_norm_layer
            self.norm_input = None
            self.proj = nn.Conv2d(
                in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=proj_bias
            )
            self.flatten = True
            self.is_linear = False

        # Create patch embedding interpolator if enabled
        if self.enable_patch_interpolator:
            from timm.layers import PatchEmbedInterpolator
            self.patch_interpolator = PatchEmbedInterpolator(
                base_patch_size=self.patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                interpolation=pos_embed_interp_mode,
                antialias=True,
            )
        else:
            self.patch_interpolator = None

        # Create normalization layer after the projection
        assert not (proj_norm_layer is True and norm_layer is None), \
            "`norm_layer` must be given when proj_norm_layer=True"
        proj_norm_layer = norm_layer if proj_norm_layer is True else (proj_norm_layer or None)
        self.norm = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()

        # Create position embedding if needed - only for patches, never for prefix tokens
        if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
            raise ValueError(
                "Cannot initialize position embeddings without grid_size. "
                "Please provide img_size or pos_embed_grid_size.")
        self.pos_embed: Optional[torch.Tensor] = None
        self.pos_embed_y: Optional[torch.Tensor] = None
        self.pos_embed_x: Optional[torch.Tensor] = None
        if not pos_embed or pos_embed == 'none':
            self.pos_embed_type = 'none'
        elif pos_embed == 'rope':
            # NOTE: RoPE is not implemented yet; forward() raises NotImplementedError for this type.
            self.pos_embed_type = 'rope'
        elif pos_embed == 'factorized':
            assert self.pos_embed_grid_size is not None
            h, w = self.pos_embed_grid_size
            self.pos_embed_type = 'factorized'
            self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02)
            self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02)
        else:
            # 'learned' (and any unrecognized value): full 2D table of shape (1, H, W, C)
            assert self.pos_embed_grid_size is not None
            h, w = self.pos_embed_grid_size
            self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02)
            self.pos_embed_type = 'learned'

        # Dropout layers
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            from timm.layers.patch_dropout import PatchDropout
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()

    def feature_info(self, location) -> Dict[str, Any]:
        """Get feature information for feature extraction.

        Args:
            location: Feature extraction location identifier (currently unused).

        Returns:
            Dictionary with the embedding channel count (``num_chs``) and the
            reduction factor (``reduction``, the (h, w) patch size tuple).
        """
        return dict(num_chs=self.embed_dim, reduction=self.patch_size)

    def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]:
        """Get the feature reduction ratio (stride) of the patch embedding.

        Args:
            as_scalar: Whether to return the maximum dimension as a scalar

        Returns:
            Feature reduction ratio as scalar or tuple
        """
        if as_scalar:
            return max(self.patch_size)
        else:
            return self.patch_size

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """Calculate grid (feature) size for given image size.

        Rounds up when dynamic padding is enabled, otherwise rounds down.

        Args:
            img_size: Input image size as (height, width)

        Returns:
            Grid size as (grid_height, grid_width)
        """
        if self.dynamic_img_pad:
            return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
        else:
            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]

    #@torch.compiler.disable()
    def _apply_learned_naflex_pos_embed(
            self,
            x: torch.Tensor,
            patch_coord: torch.Tensor,
    ) -> None:
        """Apply learned position embeddings to NaFlex batch in-place.

        Interpolates learned 2D position embeddings for each sample in the batch
        based on their individual grid sizes. Samples sharing a grid size are
        interpolated once. The flattened (row-major) table is added to the leading
        h*w positions of each sequence, so this assumes valid patches come first
        in row-major order.

        Args:
            x: Input tensor to add position embeddings to [B, N, C]
            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
        """
        # Calculate grid sizes from patch coordinates
        naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
        orig_h, orig_w = self.pos_embed.shape[1:3]
        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W

        def _interp2d(size):
            """
            Return a flattened positional-embedding grid at an arbitrary spatial resolution.
            Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into
            a (1, H*W, C) sequence that matches the requested size.
            """
            if (size[0] == orig_h) and (size[1] == orig_w):
                pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
            else:
                # AR-preserving: resize to a square of the longest side, then crop to (h, w)
                _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size
                pos_embed_flat = F.interpolate(
                    pos_embed_nchw,
                    size=_interp_size,
                    mode=self.pos_embed_interp_mode,
                    align_corners=False,
                    antialias=True,
                )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
            return pos_embed_flat.to(dtype=x.dtype)

        # Determine unique grid sizes to avoid duplicate interpolation
        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
        for bi, k in enumerate(naflex_grid_sizes):
            # k = h << 16 | w  # FIXME can get jit compat with this
            size_to_indices.setdefault(k, []).append(bi)

        for k, batch_indices in size_to_indices.items():
            # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
            # Interpolate only once for this (h, w)
            pos_embed_flat = _interp2d(k)
            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
            # In-place add into the batch rows that share this grid size
            x[:, :seq_len].index_add_(
                0,
                torch.as_tensor(batch_indices, device=x.device),
                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
            )

    def _apply_learned_naflex_pos_embed_grid_sample(
            self,
            x: torch.Tensor,
            patch_coord: torch.Tensor,
    ) -> None:
        """Apply learned position embeddings to NaFlex batch using grid_sample.

        Uses F.grid_sample for efficient interpolation of learned 2D position embeddings
        based on patch coordinates. Based on proposal by https://github.com/stas-sl

        Args:
            x: Input tensor to add position embeddings to [B, N, C]
            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
        """
        device = x.device
        B, N, C = x.shape
        shapes = patch_coord.max(dim=1).values + 1  # (B, 2) containing [h_i, w_i]

        if self.pos_embed_ar_preserving:
            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
            L_global = L_i.amax()
            grid_size_y = grid_size_x = L_global
            scale_x = scale_y = L_global / L_i  # uniform zoom (B,)
        else:
            grid_size_y, grid_size_x = shapes.amax(dim=0)  # (2,)
            scale_y = grid_size_y / shapes[:, 0]  # vertical zoom (B,)
            scale_x = grid_size_x / shapes[:, 1]  # horizontal zoom (B,)

        theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
        theta[:, 0, 0] = scale_x
        theta[:, 1, 1] = scale_y
        theta[:, 0, 2] = scale_x - 1  # translate x
        theta[:, 1, 2] = scale_y - 1  # translate y

        grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False)
        pos_embed = F.grid_sample(
            self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
            grid,
            mode=self.pos_embed_interp_mode,
            align_corners=False,
            padding_mode='border',
        ).to(dtype=x.dtype)  # (B, C, H_out, W_out)

        bi = torch.arange(B, device=device).unsqueeze(1)
        # Gather per-patch embeddings at each (y, x) coordinate -> (B, N, C)
        x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]  # NOTE leave as '+='

    def _apply_learned_pos_embed(
            self,
            x: torch.Tensor,
            grid_size: List[int],
    ) -> None:
        """Apply learned position embeddings to standard 2D batch in-place.

        Interpolates learned 2D position embeddings to match the specified grid size.

        Args:
            x: Input tensor to add position embeddings to [B, H*W, C]
            grid_size: Target grid size as [height, width]
        """
        orig_h, orig_w = self.pos_embed.shape[1:3]
        if grid_size[0] == orig_h and grid_size[1] == orig_w:
            # No resize needed, just flatten
            pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
        else:
            # Resize if needed - directly using F.interpolate
            if self.pos_embed_ar_preserving:
                L = max(grid_size)
                _interp_size = L, L
            else:
                _interp_size = grid_size
            pos_embed_flat = F.interpolate(
                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                size=_interp_size,
                mode=self.pos_embed_interp_mode,
                align_corners=False,
                antialias=True,
            )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2)
        pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

        x.add_(pos_embed_flat)

    def _apply_factorized_naflex_pos_embed(
            self,
            x: torch.Tensor,
            patch_coord: torch.Tensor,
    ) -> None:
        """Apply factorized position embeddings to NaFlex batch in-place.

        Uses separate Y and X position embedding tables that are interpolated
        and combined for each sample's grid size. Like the learned variant, the
        combined row-major table is added to the leading h*w sequence positions.

        Args:
            x: Input tensor to add position embeddings to [B, N, C]
            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
        """
        # Calculate grid sizes from patch coordinates
        naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
        assert len(naflex_grid_sizes) == x.size(0)  # one (H,W) per sample

        # Handle each batch element separately with its own grid size
        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]

        # bucket samples that share the same (H, W) so we build each grid once
        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
        for bi, k in enumerate(naflex_grid_sizes):
            size_to_indices.setdefault(k, []).append(bi)

        def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
            """
            Resample a 1-D positional-embedding table to specified length
            and return it in (1, L, C) layout, dtype matching x.
            """
            if new_length == orig_length:
                return table.to(dtype=x.dtype)
            return F.interpolate(
                table.permute(0, 2, 1).float(),  # (1,C,L) → (1,C,L_out)
                size=new_length,
                mode='linear',
                align_corners=False,
            ).permute(0, 2, 1).to(dtype=x.dtype)  # → (1,L_out,C)

        for k, batch_indices in size_to_indices.items():
            target_h, target_w = k
            if self.pos_embed_ar_preserving:
                len_y = len_x = max(target_h, target_w)
            else:
                len_y, len_x = target_h, target_w

            pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h]  # (1,H,C)
            pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w]  # (1,W,C)

            # Broadcast, add and flatten to sequence layout (row major)
            pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)  # (1,H,W,C)
            pos = pos.flatten(1, 2)

            seq_len = min(x.shape[1], pos.shape[1])
            x[:, :seq_len].index_add_(
                0,
                torch.as_tensor(batch_indices, device=x.device),
                pos[:, :seq_len].expand(len(batch_indices), -1, -1)
            )

    def _apply_factorized_naflex_pos_embed_grid_sample(
            self,
            x: torch.Tensor,
            patch_coord: torch.Tensor,
    ) -> None:
        """Apply factorized position embeddings to NaFlex batch using grid_sample.

        Uses F.grid_sample for efficient interpolation of separate Y and X position
        embedding tables based on patch coordinates. Based on proposal by https://github.com/stas-sl

        Args:
            x: Input tensor to add position embeddings to [B, N, C]
            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
        """
        device = x.device
        B, _, C = x.shape
        shapes = patch_coord.amax(dim=1) + 1

        if self.pos_embed_ar_preserving:
            # Aspect ratio preserving mode: use square grid with uniform scaling
            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
            L_global = L_i.amax()
            grid_size_y = grid_size_x = L_global
            scale_x = scale_y = L_global / L_i  # uniform zoom (B,)
        else:
            # Standard mode: different scaling for x and y
            grid_size_y, grid_size_x = shapes.amax(0)
            scale_x = grid_size_x / shapes[:, 1]  # horizontal zoom (B,)
            scale_y = grid_size_y / shapes[:, 0]  # vertical zoom (B,)

        def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor:
            pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float()  # (1, L, C) -> (B, C, 1, L)
            theta = torch.zeros(B, 2, 3, device=x.device)
            theta[:, 0, 0] = scale
            theta[:, 0, 2] = scale - 1
            theta[:, 1, 1] = 1
            grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False)
            pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border')
            return pe.to(x.dtype)

        # Interpolate along each axis
        pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x)
        pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y)

        bi = torch.arange(B, device=device).unsqueeze(1)
        x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]]

    def _apply_factorized_pos_embed(
            self,
            x: torch.Tensor,
            grid_size: List[int],
    ) -> None:
        """Apply factorized position embeddings to standard 2D batch in-place.

        Uses separate Y and X position embedding tables that are interpolated
        and combined for the specified grid size.

        Args:
            x: Input tensor to add position embeddings to [B, H*W, C]
            grid_size: Target grid size as [height, width]
        """
        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
        target_h, target_w = grid_size

        if self.pos_embed_ar_preserving:
            len_y = len_x = max(target_h, target_w)
        else:
            len_y, len_x = target_h, target_w

        def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
            if new_length == orig_length:
                return table.to(dtype=x.dtype)
            return F.interpolate(
                table.permute(0, 2, 1).float(),  # (1,L,C) -> (1,C,L)
                size=new_length,
                mode='linear',
                align_corners=False,
            ).permute(0, 2, 1).to(dtype=x.dtype)  # (1,L,C)

        # Interpolate embeddings
        pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h]  # (1,H,C)
        pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w]  # (1,W,C)

        # Broadcast, add and flatten to sequence layout (row major)
        pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)  # (1, H, W, C)
        pos_embed_flat = pos_embed.flatten(1, 2)  # (1, H*W, C)

        x.add_(pos_embed_flat)

    def forward(
            self,
            x: torch.Tensor,
            patch_coord: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for patch embedding with position encoding.

        Args:
            x: Input tensor. Supported formats:
                - [B, C, H, W] for conv mode, or linear mode when patch_coord is None
                - [B, N, P*P*C] for pre-patchified linear mode (normal)
                - [B, N, Ph, Pw, C] for pre-patchified linear mode (variable patch size)
            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode (linear mode only).

        Returns:
            Embedded tensor with position encoding and class/register tokens.
            Shape: [B, num_prefix_tokens + N, embed_dim]

        Raises:
            NotImplementedError: If the module was configured with pos_embed='rope'.
        """
        grid_size: Optional[List[int]] = None

        B = x.shape[0]
        if self.is_linear:
            # Linear embedding path, works with NaFlex mode or standard 2D mode
            if patch_coord is None:
                # Standard 2D (B, C, H, W) mode
                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
            else:
                # Pre-patchified NaFlex mode
                # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C]
                _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.')

            # Handle variable patch size projection
            if self.enable_patch_interpolator and x.ndim == 5:
                _assert(self.norm_input is None, 'input norm not supported with patch resizing')

                # Apply projection with interpolation
                x = self.patch_interpolator(
                    x,
                    self.proj.weight,
                    self.proj.bias,
                    patch_size=tuple(x.shape[2:4]),  # patch size from [B, N, Ph, Pw, C] shape
                    is_linear=True,
                )
            else:
                # Standard projection
                x = x.flatten(2)  # ensure [B, N, P*P*C], flatten Ph*Pw*C if separate
                if self.norm_input is not None:
                    x = self.norm_input(x)
                x = self.proj(x)
        else:
            _assert(x.ndim == 4, 'Convolutional input must be 4D')
            if self.dynamic_img_pad:
                # Pad bottom/right so H and W are multiples of the patch size
                H, W = x.shape[-2:]
                pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
                pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
                x = F.pad(x, (0, pad_w, 0, pad_h))

            x = self.proj(x)

            grid_size = x.shape[-2:]
            if self.flatten:
                x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC

        # Apply normalization after flattening
        x = self.norm(x)

        if self.pos_embed_type == 'learned':
            if grid_size is not None:
                # Standard 2D mode
                self._apply_learned_pos_embed(x, grid_size=grid_size)
            else:
                # NaFlex mode
                if self.pos_embed_use_grid_sample:
                    self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
                else:
                    self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord)
        elif self.pos_embed_type == 'factorized':
            if grid_size is not None:
                # Standard 2D mode
                self._apply_factorized_pos_embed(x, grid_size=grid_size)
            else:
                # NaFlex mode
                if self.pos_embed_use_grid_sample:
                    self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
                else:
                    self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord)
        elif self.pos_embed_type == 'rope':
            # Explicit raise (not `assert`) so this is not silently skipped under `python -O`
            raise NotImplementedError("ROPE not yet implemented")

        # Prepare and add class and register tokens
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(B, -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(B, -1, -1))
        # Add tokens to the beginning
        if to_cat:
            x = torch.cat(to_cat + [x], dim=1)

        # Apply dropouts
        x = self.pos_drop(x)
        x = self.patch_drop(x)

        return x
@register_notrace_function
def create_attention_mask(
        patch_valid: Optional[torch.Tensor],
        num_prefix_tokens: int = 0,
        symmetric: bool = True,
        q_len: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
    """Creates an additive attention mask from patch validity information.

    Supports two modes controlled by `symmetric`:

    1. `symmetric=True` (default): Creates a symmetric mask of shape
       [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
       both token i and token j are valid. Suitable for standard self-attention.
    2. `symmetric=False`: Creates a potentially non-square mask of shape
       [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
       the key/value token k is valid. Query token validity is not checked
       in the mask itself. Useful for cross-attention, or for self-attention
       implementations that need a query length (`q_len`) different from kv_len.

    Used for NaFlex mode to handle variable token counts and padding tokens.

    Args:
        patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
            If None, no mask is created and None is returned.
        num_prefix_tokens: Number of prefix tokens (class token, register tokens)
            to prepend, which are always considered valid.
        symmetric: If True, create a symmetric mask.
            If False, create an expanded mask based only on key/value validity.
        q_len: Query sequence length override. Only used when `symmetric` is False.
            Defaults to the key/value sequence length (`kv_len`) if None (or 0).
        dtype: Dtype of the output attention mask (e.g., torch.float32).

    Returns:
        Additive attention mask tensor: 0 where attention is allowed and
        `torch.finfo(dtype).min` (the most negative finite value, not -inf) where
        it is masked. Shape is [B, 1, seq_len, seq_len] if symmetric=True,
        or [B, 1, q_len, kv_len] if symmetric=False. None if patch_valid is None.
    """
    if patch_valid is None:
        return None

    patch_valid = patch_valid.bool()  # Ensure boolean type
    B, N = patch_valid.shape
    kv_len = N  # Initial key/value length is the number of patches

    # Prepend prefix tokens if any
    if num_prefix_tokens > 0:
        # Prefix tokens are always valid; build them on the same device as patch_valid
        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
        # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
        kv_len += num_prefix_tokens  # Update total key/value sequence length

    if symmetric:
        # Symmetric mask is True where BOTH query and key are valid
        mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
        mask_bool = mask_bool.unsqueeze(1)  # Add head dimension: [B, 1, seq_len, seq_len]
    else:
        # Expanded mask: only key/value validity matters, broadcast over queries
        q_len = q_len or kv_len
        mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)

    # Create the float mask and apply masking using additive mask convention
    mask_float = torch.zeros_like(mask_bool, dtype=dtype)
    # Finite minimum (rather than -inf) avoids NaNs in softmax for fully-masked rows
    mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)

    return mask_float
@register_notrace_function
def global_pool_naflex(
        x: torch.Tensor,
        patch_valid: Optional[torch.Tensor] = None,
        pool_type: str = 'token',
        num_prefix_tokens: int = 1,
        reduce_include_prefix: bool = False,
) -> torch.Tensor:
    """Global pooling with NaFlex support for masked tokens.

    Applies global pooling while respecting patch validity masks to exclude
    padding tokens from pooling operations. When ``patch_valid`` is None, or the
    pool type is not a reduction ('avg', 'avgmax', 'max'), this defers to
    ``global_pool_nlc`` unchanged.

    Args:
        x: Input tensor with shape [B, N, C]
        patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens].
            Any dtype is accepted; non-zero / True entries are treated as valid.
        pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max')
        num_prefix_tokens: Number of prefix tokens (class/register)
        reduce_include_prefix: Whether to include prefix tokens in pooling reduction

    Returns:
        Pooled tensor with shape [B, C]
    """
    if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
        # Fall back to standard (unmasked) pooling
        return global_pool_nlc(
            x,
            pool_type=pool_type,
            num_prefix_tokens=num_prefix_tokens,
            reduce_include_prefix=reduce_include_prefix,
        )

    # Normalize to bool so `~` is a logical NOT and masking is boolean. An integer 0/1 mask
    # would otherwise be bitwise-inverted (-1/-2) and used as integer gather indices.
    patch_valid = patch_valid.bool()

    # For NaFlex mode, apply masked pooling to exclude padding tokens
    if num_prefix_tokens > 0:
        if reduce_include_prefix:
            # Prefix tokens are always valid; patch_valid only covers the patch tokens
            prefix_valid = patch_valid.new_ones((x.shape[0], num_prefix_tokens))
            patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
        else:
            # Exclude prefix tokens from pooling (default behavior)
            x = x[:, num_prefix_tokens:]

    valid = patch_valid.unsqueeze(-1)  # [B, N, 1], broadcasts over channels

    if pool_type == 'max':
        # Masked positions get the dtype's most negative value so they never win the max
        return x.masked_fill(~valid, torch.finfo(x.dtype).min).amax(dim=1)

    # Masked mean: sum of valid tokens / count of valid tokens (count clamped to avoid div by 0)
    valid_float = valid.to(x.dtype)
    masked_avg = (x * valid_float).sum(dim=1) / valid_float.sum(dim=1).clamp(min=1)
    if pool_type == 'avg':
        return masked_avg

    # 'avgmax': mean of masked average and masked max
    masked_max = x.masked_fill(~valid, torch.finfo(x.dtype).min).amax(dim=1)
    return 0.5 * (masked_avg + masked_max)
class NaFlexVit(nn.Module):
    """NaFlexVit: Vision Transformer with NaFlex support for flexible input handling.

    A flexible implementation of Vision Transformer that supports:
    - Standard image classification with various pooling strategies
      ('token', 'avg', 'avgmax', 'max', 'map')
    - NaFlex functionality for variable aspect ratios and resolutions
    - Linear patch embedding for pre-patchified inputs
    - Position embedding strategies: learned, factorized, or none
    - Attention masking of padding tokens for batched variable-length inputs
    - Encapsulated embedding and position encoding in a NaFlexEmbeds module
    - Compatible with standard ViT checkpoints through checkpoint filtering
    """

    def __init__(
            self,
            cfg: Optional[NaFlexVitCfg] = None,
            in_chans: int = 3,
            num_classes: int = 1000,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            **kwargs,
    ) -> None:
        """Initialize NaFlexVit model.

        Args:
            cfg: Model configuration. If None, uses default NaFlexVitCfg.
            in_chans: Number of input image channels.
            num_classes: Number of classification classes.
            img_size: Input image size (for backwards compatibility with classic vit).
            **kwargs: Additional config parameters to override cfg values.
        """
        super().__init__()

        # Initialize config
        cfg = cfg or NaFlexVitCfg()
        if kwargs:
            cfg = _overlay_kwargs(cfg, **kwargs)

        # Validate configuration
        assert cfg.global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
        assert cfg.class_token or cfg.global_pool != 'token'
        assert cfg.pos_embed in ('', 'none', 'learned', 'factorized')

        # Resolve layer implementations
        norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm
        embed_norm_layer = get_norm_layer(cfg.embed_norm_layer)
        act_layer = get_act_layer(cfg.act_layer) or nn.GELU
        block_fn = cfg.block_fn or Block  # TODO: Support configurable block_fn via string lookup
        mlp_layer = cfg.mlp_layer or Mlp  # TODO: Support configurable mlp_layer via string lookup

        # Resolve the fc_norm default up front. The final `norm` and the post-pool `fc_norm`
        # are mutually exclusive, so both must be decided from the same resolved value
        # (previously `norm` used the raw cfg value, creating both norms for fc_norm=None + 'avg').
        fc_norm = cfg.fc_norm
        if fc_norm is None:
            fc_norm = cfg.global_pool == 'avg'

        # Store instance variables
        self.num_classes = num_classes
        self.global_pool = cfg.global_pool
        self.num_features = self.head_hidden_size = self.embed_dim = cfg.embed_dim  # for consistency with other models
        self.num_prefix_tokens = 1 if cfg.class_token else 0
        self.num_prefix_tokens += cfg.reg_tokens
        self.num_reg_tokens = cfg.reg_tokens
        self.has_class_token = cfg.class_token
        self.pool_include_prefix = cfg.pool_include_prefix
        self.grad_checkpointing = False

        # Initialize embedding module (includes patch, position embedding, and class/reg tokens)
        # NaFlexEmbeds is always used - handles both linear and conv embedding
        self.embeds = NaFlexEmbeds(
            patch_size=cfg.patch_size,
            in_chans=in_chans,
            embed_dim=cfg.embed_dim,
            proj_type=cfg.embed_proj_type,
            proj_bias=not cfg.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            class_token=cfg.class_token,
            reg_tokens=cfg.reg_tokens,
            default_img_size=img_size,
            dynamic_img_pad=cfg.dynamic_img_pad,
            pos_embed=cfg.pos_embed,
            pos_embed_grid_size=cfg.pos_embed_grid_size,
            pos_embed_interp_mode=cfg.pos_embed_interp_mode,
            pos_embed_ar_preserving=cfg.pos_embed_ar_preserving,
            pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample,
            proj_norm_layer=embed_norm_layer,
            pos_drop_rate=cfg.pos_drop_rate,
            patch_drop_rate=cfg.patch_drop_rate,
            enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False),
        )
        self.norm_pre = norm_layer(cfg.embed_dim) if cfg.pre_norm else nn.Identity()

        # Transformer blocks
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=cfg.embed_dim,
                num_heads=cfg.num_heads,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                qk_norm=cfg.qk_norm,
                proj_bias=cfg.proj_bias,
                init_values=cfg.init_values,
                proj_drop=cfg.proj_drop_rate,
                attn_drop=cfg.attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(cfg.depth)])

        # Feature info for downstream tasks
        patch_reduction = self.embeds.feat_ratio(as_scalar=True)
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=cfg.embed_dim, reduction=patch_reduction)
            for i in range(cfg.depth)
        ]

        # Final norm is skipped when fc_norm (applied after pooling) takes its place
        self.norm = norm_layer(cfg.embed_dim) if cfg.final_norm and not fc_norm else nn.Identity()

        # Classifier Head
        if cfg.global_pool == 'map':
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=cfg.num_heads,
                mlp_ratio=cfg.mlp_ratio,
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
        else:
            self.attn_pool = None

        self.fc_norm = norm_layer(cfg.embed_dim) if cfg.final_norm and fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(cfg.drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if cfg.weight_init != 'skip':
            self.init_weights(cfg.weight_init)
        if cfg.fix_init:
            self.fix_init_weight()

    def fix_init_weight(self) -> None:
        """Apply initialization weight fix with layer-wise scaling.

        Divides each block's attention output projection and MLP fc2 weights by
        sqrt(2 * layer_id), with layer_id starting at 1.
        """
        def rescale(param: torch.Tensor, _layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def init_weights(self, mode: str = '') -> None:
        """Initialize model weights according to specified scheme.

        Args:
            mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '')
        """
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        # Negative log-prior head bias for 'nlhb'; log(0) is undefined for a headless model
        # (num_classes == 0), so fall back to a zero bias there.
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode and self.num_classes > 0 else 0.
        named_apply(get_init_weights_vit(mode, head_bias), self)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
        """Load weights from a classic-ViT style checkpoint file.

        Renames top-level token / pos_embed / patch_embed keys onto the ``embeds`` module,
        then defers to ``vision_transformer._load_weights``.

        Args:
            checkpoint_path: Path to a checkpoint file readable by ``torch.load``.
            prefix: Key prefix forwarded to the underlying loader.
        """
        # Custom loading for the new model structure
        from .vision_transformer import _load_weights as _orig_load_weights

        def _load_weights_adapter(model, checkpoint_path, prefix=''):
            """Adapter function to handle the different model structure"""
            # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location='cpu')
            if isinstance(state_dict, dict) and 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            # Map original keys to new structure
            for k in list(state_dict.keys()):
                if k.startswith('cls_token'):
                    state_dict['embeds.' + k] = state_dict.pop(k)
                elif k.startswith('reg_token'):
                    state_dict['embeds.' + k] = state_dict.pop(k)
                elif k.startswith('pos_embed'):
                    state_dict['embeds.' + k] = state_dict.pop(k)
                elif k.startswith('patch_embed'):
                    # strip 'patch_embed.' (12 chars): patch_embed.X -> embeds.X
                    state_dict['embeds.' + k[12:]] = state_dict.pop(k)

            # NOTE(review): confirm `_load_weights` accepts a state dict here; it is passed
            # where its name/signature suggests a checkpoint path.
            return _orig_load_weights(model, state_dict, prefix)

        _load_weights_adapter(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        """Get set of parameter names that should not have weight decay applied.

        Returns:
            Set of parameter names to skip during weight decay
        """
        skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
        return skip_list

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        """Get parameter group matcher for optimizer parameter grouping.

        Args:
            coarse: Whether to use coarse-grained grouping (currently unused)

        Returns:
            Dictionary mapping group names to regex patterns
        """
        return dict(
            stem=r'^embeds',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable or disable gradient checkpointing for memory efficiency.

        Args:
            enable: Whether to enable gradient checkpointing
        """
        self.grad_checkpointing = enable
        if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'):
            self.embeds.patch_embed.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Get the classification head module.

        Returns:
            Classification head module
        """
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Reset the classification head with new number of classes and pooling.

        Args:
            num_classes: Number of classes for new classification head
            global_pool: Optional new global pooling type. Switching *to* 'map' is not
                supported if the model was built without attention pooling.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
            if global_pool == 'map' and self.attn_pool is None:
                assert False, "Cannot currently add attention pooling in reset_classifier()."
            elif global_pool != 'map' and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
            x: Union[torch.Tensor, Dict[str, torch.Tensor]],
            indices: Optional[Union[int, List[int]]] = None,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
            output_dict: bool = False,
            patch_coord: Optional[torch.Tensor] = None,
            patch_valid: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
            output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
            attn_mask: Optional attention mask for masked attention

        Returns:
            A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
            'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
        """
        # FIXME unfinished / untested
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        if isinstance(x, Dict):
            # Handle dictionary input from NaFlex collator
            patch_coord = x['patch_coord']
            patch_valid = x['patch_valid']
            patches = x['patches']
            assert False, 'WIP, patch mode needs more work'
        else:
            patches = x
            height, width = x.shape[-2:]
            H, W = self.embeds.dynamic_feat_size((height, width))

        # Create attention mask if patch_valid is provided and mask is not.
        # Keyword args are required: the 3rd positional parameter is `symmetric`, not `dtype`.
        if attn_mask is None and patch_valid is not None:
            attn_mask = create_attention_mask(
                patch_valid,
                num_prefix_tokens=self.num_prefix_tokens,
                dtype=patches.dtype,
            )

        # Forward pass through embedding
        x = self.embeds(patches, patch_coord=patch_coord)
        x = self.norm_pre(x)

        # Forward pass through blocks
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]

        for i, blk in enumerate(blocks):
            if attn_mask is not None:
                x = blk(x, attn_mask=attn_mask)
            elif self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x)
            else:
                x = blk(x)
            if i in take_indices:
                # normalize intermediates with final norm layer if enabled
                intermediates.append(self.norm(x) if norm else x)

        # Process intermediates
        if self.num_prefix_tokens:
            # split prefix (e.g. class, register) and spatial feature tokens
            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
        else:
            prefix_tokens = None

        if reshape:
            # reshape to BCHW output format
            intermediates = [
                y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
                for y in intermediates
            ]

        # For dictionary output
        if output_dict:
            result_dict = {}
            # Intermediates are always included
            result_dict['image_intermediates'] = intermediates
            if prefix_tokens is not None and return_prefix_tokens:
                result_dict['image_intermediates_prefix'] = prefix_tokens

            # Only include features if not intermediates_only
            if not intermediates_only:
                x_final = self.norm(x)
                result_dict['image_features'] = x_final

            return result_dict

        # For non-dictionary output, maintain the original behavior
        if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
            # return_prefix not support in torchscript due to poor type handling
            intermediates = list(zip(intermediates, prefix_tokens))

        if intermediates_only:
            return intermediates

        x = self.norm(x)

        return x, intermediates

    def forward_features(
            self,
            x: torch.Tensor,
            patch_coord: Optional[torch.Tensor] = None,
            patch_valid: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the input and run it through the transformer blocks and final norm.

        Args:
            x: Image or pre-patchified input passed to the embedding module.
            patch_coord: Optional patch coordinates for NaFlex mode.
            patch_valid: Optional patch validity mask; used to build `attn_mask` when none is given.
            attn_mask: Optional precomputed additive attention mask.

        Returns:
            Token features after the final norm (prefix tokens included).
        """
        if attn_mask is None:
            # Returns None when patch_valid is None
            attn_mask = create_attention_mask(
                patch_valid,
                num_prefix_tokens=self.num_prefix_tokens,
                dtype=x.dtype
            )

        # Pass through embedding module with patch coordinate/type support
        x = self.embeds(x, patch_coord=patch_coord)
        x = self.norm_pre(x)

        # Apply transformer blocks with masked attention if mask provided
        if attn_mask is not None:
            # nn.Sequential can't forward extra args, so apply blocks one by one with the mask
            for blk in self.blocks:
                x = blk(x, attn_mask=attn_mask)
        elif self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        x = self.norm(x)
        return x

    def _pool(
            self,
            x: torch.Tensor,
            pool_type: Optional[str] = None,
            patch_valid: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool token features to one vector per sample, masking padding tokens if given.

        Uses attention pooling when the model has it; otherwise masked global pooling
        with `pool_type` (defaults to `self.global_pool`).
        """
        if self.attn_pool is not None:
            # Single latent query attending over (optionally prefix + ) patch tokens
            attn_mask = create_attention_mask(
                patch_valid,
                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
                symmetric=False,
                q_len=1,
                dtype=x.dtype,
            )
            if not self.pool_include_prefix:
                x = x[:, self.num_prefix_tokens:]
            x = self.attn_pool(x, attn_mask=attn_mask)
            return x

        pool_type = self.global_pool if pool_type is None else pool_type

        x = global_pool_naflex(
            x,
            patch_valid,
            pool_type=pool_type,
            num_prefix_tokens=self.num_prefix_tokens,
            reduce_include_prefix=self.pool_include_prefix,
        )
        return x

    def forward_head(
            self,
            x: torch.Tensor,
            pre_logits: bool = False,
            patch_valid: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool features and apply fc_norm, dropout and (unless `pre_logits`) the classifier."""
        x = self._pool(x, patch_valid=patch_valid)
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(
            self,
            x: Union[torch.Tensor, Dict[str, torch.Tensor]],
            patch_coord: Optional[torch.Tensor] = None,
            patch_valid: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with optional NaFlex support.

        Args:
            x: Input tensor. Supported formats:
                - [B, C, H, W] standard image input
                - [B, N, P*P*C] pre-patchified tensor (flattened patches)
                - [B, N, Ph, Pw, C] pre-patchified tensor (variable patch size)
                - Dict from NaFlex collator (its 'patch_coord' / 'patch_valid' entries
                  take precedence over the keyword arguments)
            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode.
            patch_valid: Optional patch validity indicators for NaFlex.

        Returns:
            Model output tensor.
        """
        if isinstance(x, Dict):
            # Handle dictionary input from NaFlex collator
            patch_coord = x['patch_coord']
            patch_valid = x['patch_valid']
            patches = x['patches']
        else:
            patches = x

        # Create attention mask if patch_valid is provided (None otherwise)
        attn_mask = create_attention_mask(
            patch_valid,
            num_prefix_tokens=self.num_prefix_tokens,
            dtype=patches.dtype,
        )

        # Forward features with mask
        x = self.forward_features(
            patches,
            patch_coord=patch_coord,
            patch_valid=patch_valid,
            attn_mask=attn_mask,
        )

        # Pass mask to forward_head for masked pooling
        x = self.forward_head(
            x,
            patch_valid=patch_valid,
        )
        return x
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
    """Select the per-module ViT weight init function for a given mode.

    Re-uses the initializers from ``vision_transformer`` to maintain compatibility.

    Args:
        mode: Init scheme; any mode containing 'jax' selects the JAX init (with
            ``head_bias`` bound), any containing 'moco' selects the MoCo init,
            anything else the default timm init.
        head_bias: Classifier bias value, only used by the JAX init.

    Returns:
        Callable suitable for ``named_apply``.
    """
    from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm

    if 'jax' in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    if 'moco' in mode:
        return init_weights_vit_moco
    return init_weights_vit_timm
def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]:
    """Handle state dict conversion from original ViT to the new version with combined embedding.

    Remaps classic VisionTransformer keys onto the NaFlexVit ``embeds`` module:
      * ``pos_embed`` -> ``embeds.pos_embed``, reshaped from (1, N, C) to (1, H, W, C) when the
        grid can be resolved. If the checkpoint's pos_embed also covered the prefix tokens, those
        rows are stripped and added onto the matching ``cls_token`` / ``reg_token`` params.
      * ``cls_token`` / ``reg_token`` -> ``embeds.cls_token`` / ``embeds.reg_token``
      * ``patch_embed.<suffix>`` -> ``embeds.<suffix>``; a 4D conv ``proj.weight`` is flattened
        to a linear weight unless the model's ``embeds.proj.weight`` is itself 4D.
      * all other keys are passed through unchanged.

    Args:
        state_dict: Source checkpoint state dict. Neither the dict nor its tensors are modified.
        model: Target NaFlexVit model, used to inspect the embedding module.

    Returns:
        New state dict keyed for NaFlexVit.
    """
    out_dict = {}
    # pos_embed rows belonging to prefix tokens, keyed by the output token param they are
    # folded into. Applied after the loop so the result doesn't depend on key order and
    # the caller's tensors are not modified in place.
    prefix_pos_embeds: Dict[str, torch.Tensor] = {}

    for k, v in state_dict.items():
        if k == 'pos_embed':
            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
                num_cls_token = state_dict['cls_token'].shape[1] if 'cls_token' in state_dict else 0
                num_reg_token = state_dict['reg_token'].shape[1] if 'reg_token' in state_dict else 0
                num_prefix_tokens = num_cls_token + num_reg_token

                num_patches = v.shape[1]
                num_patches_no_prefix = num_patches - num_prefix_tokens
                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
                grid_size = math.sqrt(num_patches)
                if (grid_size_no_prefix != grid_size
                        and grid_size_no_prefix.is_integer() and not grid_size.is_integer()):
                    # Only the prefix-stripped length forms a square grid, so the original
                    # pos_embed included the prefix tokens: split them off.
                    num_patches = num_patches_no_prefix
                    if num_cls_token:
                        prefix_pos_embeds['embeds.cls_token'] = v[:, :num_cls_token]
                    if num_reg_token:
                        # register tokens follow the class token(s): rows [num_cls, num_cls + num_reg)
                        prefix_pos_embeds['embeds.reg_token'] = v[:, num_cls_token:num_prefix_tokens]
                    v = v[:, num_prefix_tokens:]
                    grid_size = grid_size_no_prefix
                grid_size = int(grid_size)

                if grid_size * grid_size == num_patches:
                    # Square grid: reshape from (1, N, C) to (1, H, W, C)
                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
                else:
                    # Not a square grid; use the patch embed grid size if the model exposes one
                    patch_embed = getattr(model.embeds, 'patch_embed', None)
                    if hasattr(patch_embed, 'grid_size'):
                        h, w = patch_embed.grid_size
                        if h * w == num_patches:
                            v = v.reshape(1, h, w, v.shape[2])
                        else:
                            _logger.warning(
                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
                                f"Using default initialization and will resize in forward pass."
                            )
                    # Otherwise keep v as is, the forward pass will handle resizing
            out_dict['embeds.pos_embed'] = v
        elif k == 'cls_token':
            out_dict['embeds.cls_token'] = v
        elif k == 'reg_token':
            out_dict['embeds.reg_token'] = v
        elif k.startswith('patch_embed.'):
            # Convert patch_embed.X to embeds.X
            suffix = k[len('patch_embed.'):]
            if suffix == 'proj.weight' and v.ndim == 4:
                target_weight = getattr(getattr(model.embeds, 'proj', None), 'weight', None)
                if target_weight is None or target_weight.ndim != 4:
                    # conv weight (out, C, Ph, Pw) -> linear weight (out, Ph * Pw * C)
                    v = v.permute(0, 2, 3, 1).flatten(1)
            out_dict['embeds.' + suffix] = v
        else:
            out_dict[k] = v

    # Fold prefix-token position embeddings into the token params (out of place)
    for key, pos in prefix_pos_embeds.items():
        out_dict[key] = out_dict[key] + pos

    return out_dict
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
    """Build a pretrained-config dict with NaFlexVit defaults.

    Defaults: 1000 classes, 3x384x384 input, crop_pct 1.0, bicubic interpolation,
    Inception mean/std, first conv 'embeds.proj', classifier 'head', apache-2.0.
    Any keyword argument overrides (or extends) these defaults.
    """
    pretrained_cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 384, 384),
        pool_size=None,
        crop_pct=1.0,
        interpolation='bicubic',
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv='embeds.proj',
        classifier='head',
        license='apache-2.0',
    )
    pretrained_cfg.update(kwargs)
    return pretrained_cfg
# Pretrained / default configs keyed by '<model_name>.<tag>'. Every @register_model entrypoint
# below has an entry; models without released weights use the '.untrained' tag.
default_cfgs = generate_default_cfgs({
    'naflexvit_base_patch16_gap.e300_s576_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'naflexvit_base_patch16_par_gap.e300_s576_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'naflexvit_base_patch16_parfac_gap.e300_s576_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'naflexvit_base_patch16_map.untrained': _cfg(),
    'naflexvit_so150m2_patch16_reg1_gap.untrained': _cfg(),
    'naflexvit_so150m2_patch16_reg1_map.untrained': _cfg(),
    'naflexvit_base_patch16_siglip.untrained': _cfg(),
    'naflexvit_so400m_patch16_siglip.untrained': _cfg(),
})
def _create_naflexvit(variant: str, pretrained: bool = False, **kwargs) -> NaFlexVit:
    """Build a NaFlexVit via ``build_model_with_cfg``.

    Any keyword whose name is a ``NaFlexVitCfg`` field is removed from ``kwargs`` and
    overlaid onto ``cfg`` (default: a fresh ``NaFlexVitCfg()``); the remaining kwargs
    are forwarded to ``build_model_with_cfg``.

    Args:
        variant: Model variant name (used for pretrained config lookup).
        pretrained: Whether to load pretrained weights.
        **kwargs: May include ``out_indices`` (default 3), ``cfg``, cfg field overrides,
            and other builder / constructor kwargs.
    """
    out_indices = kwargs.pop('out_indices', 3)
    cfg = kwargs.pop('cfg', NaFlexVitCfg())

    # Move cfg-field overrides out of kwargs so they are not passed twice
    known_fields = {f.name for f in fields(NaFlexVitCfg)}
    overrides = {}
    for name in list(kwargs):
        if name in known_fields:
            overrides[name] = kwargs.pop(name)
    if overrides:
        cfg = _overlay_kwargs(cfg, **overrides)

    return build_model_with_cfg(
        NaFlexVit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        cfg=cfg,
        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )
def _create_naflexvit_from_classic(
        variant: str,
        pretrained: bool = False,
        **kwargs,
) -> NaFlexVit:
    """Create NaFlexVit model from classic VisionTransformer configuration.

    This function handles the parameter mapping and configuration logic needed
    to create NaFlexVit models that are compatible with classic VisionTransformer
    configurations and pretrained weights.

    Args:
        variant: Model variant name
        pretrained: Whether to load pretrained weights
        **kwargs: Classic VisionTransformer parameters

    Returns:
        NaFlexVit model instance
    """
    # Remove VisionTransformer-specific parameters that don't apply to NaFlexVit
    kwargs.pop('no_embed_class', None)
    kwargs.pop('dynamic_img_size', None)

    # Handle global pooling and fc_norm defaults that differ between ViT and NaFlexVit
    gp = kwargs.pop('global_pool', 'token')  # Original ViTs default to cls token pooling
    fc_norm = kwargs.pop('fc_norm', None)
    if fc_norm is None and gp in ('avg', 'avgmax', 'max'):
        # Mirror classic VisionTransformer, which (when fc_norm is unset) uses fc_norm for every
        # reduction pool, not just 'avg'. NaFlexVit's own default only covers 'avg', so resolve
        # it here to keep the norm layout (and checkpoint keys) matching the classic model.
        fc_norm = True

    # Set NaFlexVit-specific defaults that differ from VisionTransformer
    flex_kwargs = {
        'pos_embed_grid_size': None,  # rely on img_size (// patch_size) that will be passed through
        'class_token': kwargs.get('class_token', True),
        'global_pool': gp,
        'fc_norm': fc_norm,
        **kwargs  # User overrides take precedence
    }

    return _create_naflexvit(variant, pretrained, **flex_kwargs)
@register_model
def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base/16 with NaFlex support, 4 register tokens, global average pooling and fc_norm.
    """
    return _create_naflexvit(
        'naflexvit_base_patch16_gap',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            init_values=1e-5,
            global_pool='avg',
            reg_tokens=4,
            fc_norm=True,
        ),
        **kwargs,
    )
@register_model
def naflexvit_base_patch16_par_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base/16 with NaFlex support, aspect-ratio preserving pos embed,
    4 register tokens, global average pooling and fc_norm.
    """
    return _create_naflexvit(
        'naflexvit_base_patch16_par_gap',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            init_values=1e-5,
            pos_embed_ar_preserving=True,
            global_pool='avg',
            reg_tokens=4,
            fc_norm=True,
        ),
        **kwargs,
    )
@register_model
def naflexvit_base_patch16_parfac_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base/16 with NaFlex support, aspect-ratio preserving factorized pos embed,
    4 register tokens, global average pooling and fc_norm.
    """
    return _create_naflexvit(
        'naflexvit_base_patch16_parfac_gap',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            init_values=1e-5,
            pos_embed_ar_preserving=True,
            pos_embed='factorized',
            global_pool='avg',
            reg_tokens=4,
            fc_norm=True,
        ),
        **kwargs,
    )
@register_model
def naflexvit_base_patch16_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base/16 with NaFlex support, 1 register token and MAP (attention pooling) head.
    """
    return _create_naflexvit(
        'naflexvit_base_patch16_map',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            init_values=1e-5,
            global_pool='map',
            reg_tokens=1,
        ),
        **kwargs,
    )
@register_model
def naflexvit_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-SO150M2/16 (832 dim, 21 blocks, 13 heads) with NaFlex support, 1 register token,
    no qkv bias, global average pooling and fc_norm.

    Like all NaFlexVit models this supports:
    1. Variable aspect ratios and resolutions via patch coordinates
    2. Position embedding interpolation for arbitrary grid sizes
    3. Explicit patch coordinates and valid token masking
    """
    return _create_naflexvit(
        'naflexvit_so150m2_patch16_reg1_gap',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=832,
            depth=21,
            num_heads=13,
            mlp_ratio=34/13,
            init_values=1e-5,
            qkv_bias=False,
            reg_tokens=1,
            global_pool='avg',
            fc_norm=True,
        ),
        **kwargs,
    )
@register_model
def naflexvit_so150m2_patch16_reg1_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-SO150M2/16 (832 dim, 21 blocks, 13 heads) with NaFlex support, 1 register token,
    no qkv bias and a MAP (attention pooling) head.

    Like all NaFlexVit models this supports:
    1. Variable aspect ratios and resolutions via patch coordinates
    2. Position embedding interpolation for arbitrary grid sizes
    3. Explicit patch coordinates and valid token masking
    """
    return _create_naflexvit(
        'naflexvit_so150m2_patch16_reg1_map',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=832,
            depth=21,
            num_heads=13,
            mlp_ratio=34/13,
            init_values=1e-5,
            qkv_bias=False,
            reg_tokens=1,
            global_pool='map',
        ),
        **kwargs,
    )
@register_model
def naflexvit_base_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base/16 with NaFlex support in a SigLIP-style configuration
    (tanh-approximated GELU, MAP attention pooling head).
    """
    return _create_naflexvit(
        'naflexvit_base_patch16_siglip',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            act_layer='gelu_tanh',
            global_pool='map',
        ),
        **kwargs,
    )
@register_model
def naflexvit_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-SO400M/16 (1152 dim, 27 blocks, 16 heads) with NaFlex support in a SigLIP-style
    configuration (tanh-approximated GELU, MAP attention pooling head).
    """
    return _create_naflexvit(
        'naflexvit_so400m_patch16_siglip',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(
            patch_size=16,
            embed_dim=1152,
            depth=27,
            num_heads=16,
            mlp_ratio=3.7362,
            act_layer='gelu_tanh',
            global_pool='map',
        ),
        **kwargs,
    )