in timm/models/fastvit.py [0:0]
def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm

    Renames the keys of an original-layout checkpoint so that renamed layers use
    their new names and the flat ``network.N`` sequence is regrouped into
    ``stages.S`` (``.downsample.proj``, ``.pos_emb.pos_enc`` or ``.blocks``).

    Args:
        state_dict: Mapping of parameter names to values, or a mapping that holds
            such a mapping under the key ``'state_dict'``. When the MobileCLIP
            marker key is present, only keys starting with ``'image_encoder.model.'``
            are kept and that prefix is stripped.
        model: Target model. Only ``model.head`` is inspected, and only when the
            checkpoint contains a ``'head.proj'`` entry.

    Returns:
        The (unwrapped) mapping itself, unmodified, when it already contains
        ``'stem.0.conv_kxk.0.conv.weight'``; otherwise a new dict with remapped keys.
    """
    import bisect
    import re

    # Unwrap before the "already converted" test so that a converted checkpoint
    # nested under 'state_dict' is returned as-is instead of being remapped again
    # (which would turn 'head.fc.*' into 'head.fc.fc.*').
    state_dict = state_dict.get('state_dict', state_dict)
    if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
        return state_dict  # non-original checkpoint, no remapping needed

    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # remap MobileCLIP checkpoints
        prefix = 'image_encoder.model.'
    else:
        prefix = ''

    # find stage ends by locating downsample layers
    downsample_re = re.compile(r'^(.*?)network\.(\d+)\.proj.*')
    stage_ends = sorted({
        int(m.group(2)) for m in map(downsample_re.match, state_dict) if m
    })

    layer_scale_re = re.compile(r'layer_scale_([0-9])')
    network_re = re.compile(r'^network\.(\d+)')
    # plain substring renames, applied in this order
    renames = (
        ('patch_embed', 'stem'),
        ('rbr_conv', 'conv_kxk'),
        ('rbr_scale', 'conv_scale'),
        ('rbr_skip', 'identity'),
        ('conv_exp', 'final_conv'),  # to match byobnet, regnet, nfnet
        ('lkb_origin', 'large_conv'),
        ('convffn', 'mlp'),
        ('se.reduce', 'se.fc1'),
        ('se.expand', 'se.fc2'),
    )

    out_dict = {}
    for k, v in state_dict.items():
        if prefix:
            # the prefix must lead the key; strip only that leading occurrence
            if not k.startswith(prefix):
                continue
            k = k[len(prefix):]

        # remap renamed layers
        for old, new in renames:
            k = k.replace(old, new)
        k = layer_scale_re.sub(r'layer_scale_\1.gamma', k)
        if k.endswith('layer_scale'):
            k = k.replace('layer_scale', 'layer_scale.gamma')
        k = k.replace('dist_head', 'head_dist')

        if k.startswith('head.'):
            if (
                k == 'head.proj'
                and hasattr(model.head, 'fc')
                and isinstance(model.head.fc, nn.Linear)
            ):
                # if CLIP projection, map to head.fc w/ bias = zeros
                k = 'head.fc.weight'
                v = v.T
                out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
            else:
                k = k.replace('head.', 'head.fc.')

        # remap flat sequential network to stages
        match = network_re.match(k)
        if match:
            net_idx = int(match.group(1))
            # number of downsample indices <= net_idx gives the owning stage
            stage_idx = bisect.bisect_right(stage_ends, net_idx)
            net_prefix = f'network.{net_idx}'
            stage_prefix = f'stages.{stage_idx}'
            if net_prefix + '.proj' in k:
                k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
            elif net_prefix + '.pe' in k:
                k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
            else:
                k = k.replace(net_prefix, stage_prefix + '.blocks')

        out_dict[k] = v
    return out_dict