in timm/models/naflexvit.py [0:0]
def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]:
"""Handle state dict conversion from original ViT to the new version with combined embedding."""
from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
# Handle CombinedEmbed module pattern
out_dict = {}
for k, v in state_dict.items():
# Convert tokens and embeddings to combined_embed structure
if k == 'pos_embed':
# Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
num_cls_token = 0
num_reg_token = 0
if 'reg_token' in state_dict:
num_reg_token = state_dict['reg_token'].shape[1]
if 'cls_token' in state_dict:
num_cls_token = state_dict['cls_token'].shape[1]
num_prefix_tokens = num_cls_token + num_reg_token
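                # The original pos_embed may or may not include positions for these prefix (class / register) tokens; this is resolved below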
# Original format is (1, N, C), need to reshape to (1, H, W, C)
num_patches = v.shape[1]
num_patches_no_prefix = num_patches - num_prefix_tokens
grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
grid_size = math.sqrt(num_patches)
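                # Heuristic: if only the patch-only token count forms an integer grid, the checkpoint
                # pos_embed almost certainly included the prefix tokens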
if (grid_size_no_prefix != grid_size
and (grid_size_no_prefix.is_integer() and not grid_size.is_integer())
):
# make a decision, did the pos_embed of the original include the prefix tokens?
num_patches = num_patches_no_prefix
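                    # Fold the prefix-token position embeddings into the token parameters so the information is preserved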
cls_token_emb = v[:, 0:num_cls_token]
if cls_token_emb.numel():
state_dict['cls_token'] += cls_token_emb
                    reg_token_emb = v[:, num_cls_token:num_prefix_tokens]
if reg_token_emb.numel():
state_dict['reg_token'] += reg_token_emb
v = v[:, num_prefix_tokens:]
grid_size = grid_size_no_prefix
grid_size = int(grid_size)
# Check if it's a perfect square for a standard grid
if grid_size * grid_size == num_patches:
# Reshape from (1, N, C) to (1, H, W, C)
v = v.reshape(1, grid_size, grid_size, v.shape[2])
else:
# Not a square grid, we need to get the actual dimensions
                if hasattr(model.embeds, 'patch_embed') and hasattr(model.embeds.patch_embed, 'grid_size'):
h, w = model.embeds.patch_embed.grid_size
if h * w == num_patches:
# We have the right dimensions
v = v.reshape(1, h, w, v.shape[2])
else:
                            # Dimensions don't match; keep the checkpoint values as-is
                            _logger.warning(
                                f"Position embedding size mismatch: checkpoint={num_patches}, model={h * w}. "
                                "Keeping the checkpoint values; the forward pass will resize them."
                            )
# Keep v as is, the forward pass will handle resizing
out_dict['embeds.pos_embed'] = v
elif k == 'cls_token':
out_dict['embeds.cls_token'] = v
elif k == 'reg_token':
out_dict['embeds.reg_token'] = v
        # Convert patch_embed.X to embeds.X (the patch projection sits directly under the embeds module)
        elif k.startswith('patch_embed.'):
            suffix = k[len('patch_embed.'):]
if suffix == 'proj.weight':
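                # Reshape Conv2d patch kernel (out_c, in_c, kh, kw) -> Linear weight (out_c, kh * kw * in_c)
                # so it matches the flattened-patch input of the linear patch projection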
v = v.permute(0, 2, 3, 1).flatten(1)
new_key = 'embeds.' + suffix
out_dict[new_key] = v
else:
out_dict[k] = v
return out_dict
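
A minimal usage sketch (not part of the timm source) showing how a filter like this is typically applied when remapping an old-style ViT state dict onto a NaFlexVit model; the model name and checkpoint path below are illustrative assumptions:

import timm
import torch

# 'naflexvit_base_patch16_gap' and the checkpoint path are hypothetical examples;
# substitute a registered NaFlexVit variant and a real checkpoint file.
model = timm.create_model('naflexvit_base_patch16_gap', pretrained=False)
state_dict = torch.load('vit_base_checkpoint.pth', map_location='cpu')  # assumed to be a raw state dict
converted = checkpoint_filter_fn(state_dict, model)
missing, unexpected = model.load_state_dict(converted, strict=False)
print(f'missing: {len(missing)}, unexpected: {len(unexpected)}')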