in timm/models/mvitv2.py
import re

import torch


def checkpoint_filter_fn(state_dict, model):
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # Native timm checkpoint: keys already match the model, but relative
        # position tables may need resizing if the resolution differs.
        for k in state_dict.keys():
            if 'rel_pos' in k:
                rel_pos = state_dict[k]
                dest_rel_pos_shape = model.state_dict()[k].shape
                if rel_pos.shape[0] != dest_rel_pos_shape[0]:
                    # Linearly interpolate the rel_pos table along its first
                    # (position) dimension to the destination length.
                    rel_pos_resized = torch.nn.functional.interpolate(
                        rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                        size=dest_rel_pos_shape[0],
                        mode="linear",
                    )
                    state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
        return state_dict
    # Otherwise assume an upstream (original MViTv2) checkpoint and remap its
    # keys to the timm layout.
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depths attribute to remap checkpoints'
    # Map each flat block index to a (stage index, block-within-stage) pair.
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
        block_idx += d
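
    # Worked example (illustrative, not from the original source): with
    # depths=(1, 2), the mapping is depth_map == {0: (0, 0), 1: (1, 0), 2: (1, 1)},
    # so an upstream key such as 'blocks.2.norm1.weight' is rewritten below to
    # 'stages.1.blocks.1.norm1.weight'.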
    out_dict = {}
    for k, v in state_dict.items():
        # Rewrite flat 'blocks.N' keys into nested 'stages.S.blocks.B' keys.
        k = re.sub(
            r'blocks\.(\d+)',
            lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
            k)
        # The shortcut projection attaches to the attention or MLP branch
        # depending on where the model expands channels.
        if expand_attn:
            k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_attn', k)
        else:
            k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_mlp', k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        out_dict[k] = v

    return out_dict
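

# Minimal usage sketch (not part of the original file; the variant name and
# checkpoint path below are assumptions for illustration): an upstream MViTv2
# state dict is passed through checkpoint_filter_fn before loading.
if __name__ == '__main__':
    import timm

    model = timm.create_model('mvitv2_tiny', pretrained=False)
    # 'mvitv2_tiny_upstream.pth' is a hypothetical local checkpoint file.
    checkpoint = torch.load('mvitv2_tiny_upstream.pth', map_location='cpu')
    model.load_state_dict(checkpoint_filter_fn(checkpoint, model))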