def checkpoint_filter_fn()

in timm/models/fastvit.py [0:0]


def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm

    Handles both original Apple FastViT checkpoints and MobileCLIP image-encoder
    checkpoints (keys nested under 'image_encoder.model.'), renaming re-param
    branch layers and re-grouping the flat `network.N` sequence into timm's
    `stages.N` / `downsample` / `pos_emb` layout.

    Args:
        state_dict: Source checkpoint dict (weights may be wrapped under a
            'state_dict' key).
        model: Target timm model; only `model.head.fc` is inspected, to decide
            whether a CLIP `head.proj` matrix should be transposed into a
            Linear classifier weight.

    Returns:
        New state dict with keys remapped to timm's FastVit naming.
    """
    if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
        return state_dict  # non-original checkpoint, no remapping needed

    state_dict = state_dict.get('state_dict', state_dict)
    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # remap MobileCLIP checkpoints, only image-tower keys are kept
        prefix = 'image_encoder.model.'
    else:
        prefix = ''

    import re
    import bisect

    # compile loop-invariant patterns once, outside the per-key loops
    downsample_re = re.compile(r'^(.*?)network\.(\d+)\.proj.*')
    network_re = re.compile(r'^network\.(\d+)')
    layer_scale_re = re.compile(r'layer_scale_([0-9])')

    # find stage ends by locating downsample layers (only keys are needed)
    stage_ends = set()
    for k in state_dict:
        match = downsample_re.match(k)
        if match:
            stage_ends.add(int(match.group(2)))
    stage_ends = sorted(stage_ends)

    out_dict = {}
    for k, v in state_dict.items():
        if prefix:
            if prefix not in k:
                continue  # skip non-image-tower keys (e.g. text encoder)
            k = k.replace(prefix, '')

        # remap renamed layers
        k = k.replace('patch_embed', 'stem')
        k = k.replace('rbr_conv', 'conv_kxk')
        k = k.replace('rbr_scale', 'conv_scale')
        k = k.replace('rbr_skip', 'identity')
        k = k.replace('conv_exp', 'final_conv')  # to match byobnet, regnet, nfnet
        k = k.replace('lkb_origin', 'large_conv')
        k = k.replace('convffn', 'mlp')
        k = k.replace('se.reduce', 'se.fc1')
        k = k.replace('se.expand', 'se.fc2')
        k = layer_scale_re.sub(r'layer_scale_\1.gamma', k)
        if k.endswith('layer_scale'):
            k = k.replace('layer_scale', 'layer_scale.gamma')
        k = k.replace('dist_head', 'head_dist')
        if k.startswith('head.'):
            if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear):
                # if CLIP projection, map to head.fc w/ bias = zeros
                k = k.replace('head.proj', 'head.fc.weight')
                v = v.T
                out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
            else:
                k = k.replace('head.', 'head.fc.')

        # remap flat sequential network to stages
        match = network_re.match(k)
        stage_idx, net_idx = None, None
        if match:
            net_idx = int(match.group(1))
            # stage index == number of downsample positions at or below net_idx
            stage_idx = bisect.bisect_right(stage_ends, net_idx)
        if stage_idx is not None:
            net_prefix = f'network.{net_idx}'
            stage_prefix = f'stages.{stage_idx}'
            if net_prefix + '.proj' in k:
                k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
            elif net_prefix + '.pe' in k:
                k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
            else:
                k = k.replace(net_prefix, stage_prefix + '.blocks')

        out_dict[k] = v
    return out_dict