def checkpoint_filter_fn()

in timm/models/mvitv2.py


import re

import torch


def checkpoint_filter_fn(state_dict, model):
    """Remap MViTv2 checkpoint state dicts to timm's module/key layout."""
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # Already in timm's native layout; only resize relative position
        # tables whose length doesn't match the destination model.
        model_state = model.state_dict()
        for k in state_dict.keys():
            if 'rel_pos' in k:
                rel_pos = state_dict[k]
                dest_rel_pos_shape = model_state[k].shape
                if rel_pos.shape[0] != dest_rel_pos_shape[0]:
                    # Interpolate along the position dim: (L, C) -> (1, C, L),
                    # resize the last dim to L', then reshape back to (L', C).
                    rel_pos_resized = torch.nn.functional.interpolate(
                        rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                        size=dest_rel_pos_shape[0],
                        mode="linear",
                    )
                    state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
        return state_dict

    # Original (research-release style) checkpoints nest weights under 'model_state'.
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depths attribute to remap checkpoints'
    # Map flat block indices ('blocks.N') to per-stage indices ('stages.S.blocks.B').
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
        block_idx += d

    out_dict = {}
    for k, v in state_dict.items():
        # Remap flat 'blocks.N' keys to per-stage 'stages.S.blocks.B' keys.
        k = re.sub(
            r'blocks\.(\d+)',
            lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
            k)

        # The residual shortcut projection is renamed depending on which
        # branch (attention or MLP) carries the channel expansion.
        if expand_attn:
            k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_attn', k)
        else:
            k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_mlp', k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        out_dict[k] = v

    return out_dict
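
The rel_pos interpolation above can be sanity-checked in isolation. A minimal sketch (toy shapes, not from the original file): resize a (5, 8) relative-position table to 9 positions along the position dim.

import torch

rel_pos = torch.randn(5, 8)                      # (L, C) toy table
resized = torch.nn.functional.interpolate(
    rel_pos.reshape(1, 5, -1).permute(0, 2, 1),  # (L, C) -> (1, C, L)
    size=9,                                      # new position count L'
    mode='linear',
)
resized = resized.reshape(-1, 9).permute(1, 0)   # (1, C, L') -> (L', C)
print(resized.shape)                             # torch.Size([9, 8])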
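
The flat-to-staged index remap can likewise be checked standalone; the depths tuple here is illustrative, not taken from any particular variant's config.

depths = (1, 2, 5, 2)  # illustrative four-stage depths
depth_map, block_idx = {}, 0
for stage_idx, d in enumerate(depths):
    depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
    block_idx += d

assert depth_map[0] == (0, 0)  # flat block 0 -> stage 0, block 0
assert depth_map[3] == (2, 0)  # flat block 3 -> stage 2, block 0
assert depth_map[9] == (3, 1)  # flat block 9 -> stage 3, block 1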
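
Finally, a minimal usage sketch. timm normally applies this filter automatically while loading pretrained weights; calling it directly looks roughly like this (the checkpoint filename is hypothetical, and the checkpoint is assumed to match the chosen variant):

import timm
import torch

model = timm.create_model('mvitv2_tiny')                         # any MViTv2 variant
checkpoint = torch.load('mvitv2_tiny.pyth', map_location='cpu')  # hypothetical local file
model.load_state_dict(checkpoint_filter_fn(checkpoint, model))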