def checkpoint_filter_fn()

in timm/models/naflexvit.py [0:0]


def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]:
    """Handle state dict conversion from original ViT to the new version with combined embedding."""
    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn

    # Handle CombinedEmbed module pattern
    out_dict = {}
    for k, v in state_dict.items():
        # Convert tokens and embeddings to combined_embed structure
        if k == 'pos_embed':
            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
                num_cls_token = 0
                num_reg_token = 0
                if 'reg_token' in state_dict:
                    num_reg_token = state_dict['reg_token'].shape[1]
                if 'cls_token' in state_dict:
                    num_cls_token = state_dict['cls_token'].shape[1]
                num_prefix_tokens = num_cls_token + num_reg_token

                # Original format is (1, N, C), need to reshape to (1, H, W, C)
                num_patches = v.shape[1]
                num_patches_no_prefix = num_patches - num_prefix_tokens
                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
                grid_size = math.sqrt(num_patches)
                if (grid_size_no_prefix != grid_size
                        and (grid_size_no_prefix.is_integer() and not grid_size.is_integer())
                ):
                    # The original pos_embed appears to include the prefix tokens; split them off first.
                    num_patches = num_patches_no_prefix
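                    # Fold the prefix-token rows of pos_embed into the token params; '+=' mutates
                    # the tensors in place, so copies already written to out_dict stay in sync.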
                    cls_token_emb = v[:, 0:num_cls_token]
                    if cls_token_emb.numel():
                        state_dict['cls_token'] += cls_token_emb
                    reg_token_emb = v[:, num_cls_token:num_prefix_tokens]
                    if reg_token_emb.numel():
                        state_dict['reg_token'] += reg_token_emb
                    v = v[:, num_prefix_tokens:]
                    grid_size = grid_size_no_prefix
                grid_size = int(grid_size)

                # Check if it's a perfect square for a standard grid
                if grid_size * grid_size == num_patches:
                    # Reshape from (1, N, C) to (1, H, W, C)
                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
                else:
                    # Not a square grid; look up the actual dimensions from the model's patch embed
                    if hasattr(model.embeds.patch_embed, 'grid_size'):
                        h, w = model.embeds.patch_embed.grid_size
                        if h * w == num_patches:
                            # We have the right dimensions
                            v = v.reshape(1, h, w, v.shape[2])
                        else:
                            # Dimensions don't match; keep the checkpoint values as-is and let
                            # the forward pass handle resizing.
                            _logger.warning(
                                f"Position embedding size mismatch: checkpoint={num_patches}, model={h * w}. "
                                "Resizing will be handled in the forward pass."
                            )

            out_dict['embeds.pos_embed'] = v
        elif k == 'cls_token':
            out_dict['embeds.cls_token'] = v
        elif k == 'reg_token':
            out_dict['embeds.reg_token'] = v
        # Convert patch_embed.X to embeds.X
        elif k.startswith('patch_embed.'):
            suffix = k[len('patch_embed.'):]
            if suffix == 'proj.weight':
                # Conv2d patchify weight (D, C, P, P) -> Linear weight (D, P*P*C), NHWC patch order
                v = v.permute(0, 2, 3, 1).flatten(1)
            new_key = 'embeds.' + suffix
            out_dict[new_key] = v
        else:
            out_dict[k] = v

    return out_dict
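
The (1, N, C) to (1, H, W, C) position embedding conversion handled above can be sketched in isolation. The shapes below (one class token, a 14x14 grid, embed dim 32) are arbitrary demonstration values, not taken from any particular checkpoint.

import math
import torch

pos_embed = torch.randn(1, 1 + 14 * 14, 32)   # checkpoint layout: (1, N, C), prefix token included
num_prefix_tokens = 1

spatial = pos_embed[:, num_prefix_tokens:]                           # drop prefix rows
grid = int(math.sqrt(spatial.shape[1]))                              # 14
pos_embed_nhwc = spatial.reshape(1, grid, grid, pos_embed.shape[2])  # (1, H, W, C)
print(pos_embed_nhwc.shape)                                          # torch.Size([1, 14, 14, 32])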
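
The proj.weight handling converts a Conv2d patchify kernel into an equivalent Linear weight over NHWC-flattened patches. The snippet below is a standalone check with made-up shapes showing the two formulations agree for a single patch.

import torch

conv_w = torch.randn(8, 3, 16, 16)                     # Conv2d patch embed weight: (D, C, P, P)
lin_w = conv_w.permute(0, 2, 3, 1).flatten(1)          # Linear weight: (D, P*P*C), NHWC patch order

patch = torch.randn(3, 16, 16)                         # a single (C, P, P) image patch
conv_out = (conv_w * patch).sum(dim=(1, 2, 3))         # what the conv computes for this patch
lin_out = lin_w @ patch.permute(1, 2, 0).reshape(-1)   # linear applied to the NHWC-flattened patch
print(torch.allclose(conv_out, lin_out, atol=1e-5))    # True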
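
Putting it together, a minimal usage sketch follows. The model names are illustrative assumptions ('naflexvit_base_patch16_gap' standing in for a registered NaFlexVit variant, 'vit_base_patch16_224' for a classic ViT source); only the key remapping is exercised, no weights are loaded into the target model.

import timm
from timm.models.naflexvit import checkpoint_filter_fn

naflex_model = timm.create_model('naflexvit_base_patch16_gap', pretrained=False)  # assumed variant name
vit_state = timm.create_model('vit_base_patch16_224', pretrained=False).state_dict()

remapped = checkpoint_filter_fn(vit_state, naflex_model)
print(sorted(k for k in remapped if k.startswith('embeds.')))
# e.g. ['embeds.cls_token', 'embeds.pos_embed', 'embeds.proj.bias', 'embeds.proj.weight']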