in timm/models/regnet.py [0:0]
def _filter_fn(state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Filter and remap state dict keys for compatibility.
Args:
state_dict: Raw state dictionary.
Returns:
Filtered state dictionary.
"""
state_dict = state_dict.get('model', state_dict)
replaces = [
('f.a.0', 'conv1.conv'),
('f.a.1', 'conv1.bn'),
('f.b.0', 'conv2.conv'),
('f.b.1', 'conv2.bn'),
('f.final_bn', 'conv3.bn'),
('f.se.excitation.0', 'se.fc1'),
('f.se.excitation.2', 'se.fc2'),
('f.se', 'se'),
('f.c.0', 'conv3.conv'),
('f.c.1', 'conv3.bn'),
('f.c', 'conv3.conv'),
('proj.0', 'downsample.conv'),
('proj.1', 'downsample.bn'),
('proj', 'downsample.conv'),
]
if 'classy_state_dict' in state_dict:
# classy-vision & vissl (SEER) weights
import re
state_dict = state_dict['classy_state_dict']['base_model']['model']
out = {}
for k, v in state_dict['trunk'].items():
k = k.replace('_feature_blocks.conv1.stem.0', 'stem.conv')
k = k.replace('_feature_blocks.conv1.stem.1', 'stem.bn')
k = re.sub(
r'^_feature_blocks.res\d.block(\d)-(\d+)',
lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k)
k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k)
for s, r in replaces:
k = k.replace(s, r)
out[k] = v
for k, v in state_dict['heads'].items():
if 'projection_head' in k or 'prototypes' in k:
continue
k = k.replace('0.clf.0', 'head.fc')
out[k] = v
return out
if 'stem.0.weight' in state_dict:
# torchvision weights
import re
out = {}
for k, v in state_dict.items():
k = k.replace('stem.0', 'stem.conv')
k = k.replace('stem.1', 'stem.bn')
k = re.sub(
r'trunk_output.block(\d)\.block(\d+)\-(\d+)',
lambda x: f's{int(x.group(1))}.b{int(x.group(3)) + 1}', k)
for s, r in replaces:
k = k.replace(s, r)
k = k.replace('fc.', 'head.fc.')
out[k] = v
return out
return state_dict