in timm/models/convnext.py [0:0]
def checkpoint_filter_fn(state_dict, model):
    """ Remap FB checkpoints -> timm

    Three cases are handled, in order:

    1. If ``state_dict`` already contains ``head.norm.weight`` or
       ``norm_pre.weight`` it is returned unchanged (same object).
    2. If it contains ``visual.trunk.stem.0.weight``, only the
       ``visual.trunk.*`` entries are kept (prefix stripped) and the
       ``visual.head.*`` projection / MLP weights are mapped onto
       ``head.fc`` / ``head.pre_logits.fc``, with a zero ``head.fc.bias``.
       NOTE(review): this key layout looks like an OpenCLIP-style image
       tower -- confirm against the checkpoints actually loaded here.
    3. Otherwise every key is renamed (``downsample_layers`` -> ``stem`` /
       ``stages.N.downsample``, ``stages.N.M`` -> ``stages.N.blocks.M``,
       ``dwconv`` -> ``conv_dw``, ``pwconv`` -> ``mlp.fc``, ``grn.*`` ->
       ``mlp.grn.*``, ``head.`` -> ``head.fc.``, leading ``norm.`` ->
       ``head.norm.``) and selected tensors are reshaped.

    Args:
        state_dict: Mapping of parameter name -> tensor. May be wrapped in an
            outer mapping under the ``'model'`` key, which is unwrapped first.
        model: Object whose ``state_dict()`` supplies the target shapes for
            2-D non-head tensors. It is only queried (at most once) when such
            a tensor is encountered.

    Returns:
        The original ``state_dict`` (case 1) or a new dict with remapped keys.

    Raises:
        KeyError: If a remapped 2-D non-head key is missing from
            ``model.state_dict()``.
    """
    if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
        return state_dict  # non-FB checkpoint
    if 'model' in state_dict:
        state_dict = state_dict['model']

    if 'visual.trunk.stem.0.weight' in state_dict:
        # Keep only the trunk weights, dropping the 'visual.trunk.' prefix.
        out_dict = {
            k.replace('visual.trunk.', ''): v
            for k, v in state_dict.items()
            if k.startswith('visual.trunk.')
        }
        if 'visual.head.proj.weight' in state_dict:
            proj_weight = state_dict['visual.head.proj.weight']
            out_dict['head.fc.weight'] = proj_weight
            # No source bias is read for the projection; fill with zeros.
            out_dict['head.fc.bias'] = torch.zeros(proj_weight.shape[0])
        elif 'visual.head.mlp.fc1.weight' in state_dict:
            fc2_weight = state_dict['visual.head.mlp.fc2.weight']
            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
            out_dict['head.fc.weight'] = fc2_weight
            # Only fc2's weight is read; its bias is replaced by zeros.
            out_dict['head.fc.bias'] = torch.zeros(fc2_weight.shape[0])
        return out_dict

    import re

    # Dots are escaped so they match literal '.' separators only; an unescaped
    # '.' would also rewrite unrelated names such as 'substages_0_1'.
    stage_block_re = re.compile(r'stages\.([0-9]+)\.([0-9]+)')
    downsample_re = re.compile(r'downsample_layers\.([0-9]+)\.([0-9]+)')

    out_dict = {}
    # Fetched lazily and cached: model.state_dict() is loop-invariant, and
    # deferring it keeps `model` untouched when no tensor needs reshaping.
    model_state = None
    for k, v in state_dict.items():
        # Order matters: 'downsample_layers.0.' must become 'stem.' before the
        # generic downsample pattern runs, and the stage/block rewrite must
        # precede it so freshly produced 'stages.N.downsample.M' keys are not
        # re-matched as blocks.
        k = k.replace('downsample_layers.0.', 'stem.')
        k = stage_block_re.sub(r'stages.\1.blocks.\2', k)
        k = downsample_re.sub(r'stages.\1.downsample.\2', k)
        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
        if 'grn' in k:
            k = k.replace('grn.beta', 'mlp.grn.bias')
            k = k.replace('grn.gamma', 'mlp.grn.weight')
            # Collapse to 1-D, keeping only the size of the last dimension.
            v = v.reshape(v.shape[-1])
        k = k.replace('head.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm', 'head.norm')
        if v.ndim == 2 and 'head' not in k:
            # Reshape 2-D weights to whatever shape the model declares
            # (presumably linear -> 1x1 conv layouts; verify against model).
            if model_state is None:
                model_state = model.state_dict()
            v = v.reshape(model_state[k].shape)
        out_dict[k] = v

    return out_dict