def _filter_fn()

in timm/models/regnet.py [0:0]


def _filter_fn(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Filter and remap state dict keys for compatibility.

    Args:
        state_dict: Raw state dictionary.

    Returns:
        Filtered state dictionary.
    """
    state_dict = state_dict.get('model', state_dict)
    replaces = [
        ('f.a.0', 'conv1.conv'),
        ('f.a.1', 'conv1.bn'),
        ('f.b.0', 'conv2.conv'),
        ('f.b.1', 'conv2.bn'),
        ('f.final_bn', 'conv3.bn'),
        ('f.se.excitation.0', 'se.fc1'),
        ('f.se.excitation.2', 'se.fc2'),
        ('f.se', 'se'),
        ('f.c.0', 'conv3.conv'),
        ('f.c.1', 'conv3.bn'),
        ('f.c', 'conv3.conv'),
        ('proj.0', 'downsample.conv'),
        ('proj.1', 'downsample.bn'),
        ('proj', 'downsample.conv'),
    ]
    if 'classy_state_dict' in state_dict:
        # classy-vision & vissl (SEER) weights
        import re
        state_dict = state_dict['classy_state_dict']['base_model']['model']
        out = {}
        for k, v in state_dict['trunk'].items():
            k = k.replace('_feature_blocks.conv1.stem.0', 'stem.conv')
            k = k.replace('_feature_blocks.conv1.stem.1', 'stem.bn')
            k = re.sub(
                r'^_feature_blocks.res\d.block(\d)-(\d+)',
                lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k)
            k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k)
            for s, r in replaces:
                k = k.replace(s, r)
            out[k] = v
        for k, v in state_dict['heads'].items():
            if 'projection_head' in k or 'prototypes' in k:
                continue
            k = k.replace('0.clf.0', 'head.fc')
            out[k] = v
        return out
    if 'stem.0.weight' in state_dict:
        # torchvision weights
        import re
        out = {}
        for k, v in state_dict.items():
            k = k.replace('stem.0', 'stem.conv')
            k = k.replace('stem.1', 'stem.bn')
            k = re.sub(
                r'trunk_output.block(\d)\.block(\d+)\-(\d+)',
                lambda x: f's{int(x.group(1))}.b{int(x.group(3)) + 1}', k)
            for s, r in replaces:
                k = k.replace(s, r)
            k = k.replace('fc.', 'head.fc.')
            out[k] = v
        return out
    return state_dict