in timm/models/swin_transformer.py [0:0]
def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> Dict[str, torch.Tensor]:
    """Remap a checkpoint so that its keys and tensor shapes match ``model``.

    The checkpoint is first unwrapped from an optional ``'model'`` and/or
    ``'state_dict'`` container. Each remaining entry is then handled as follows:

    * entries whose key contains ``relative_position_index`` or ``attn_mask``
      are dropped;
    * ``patch_embed.proj.weight`` is passed through ``resample_patch_embed``
      when its last two dims differ from the model's patch projection weight;
    * keys ending in ``relative_position_bias_table`` are passed through
      ``resize_rel_pos_bias_table`` when their shape differs from the matching
      table in ``model`` or when that module's window is not square;
    * for legacy checkpoints (no ``head.fc.weight`` key), ``layers.N.downsample``
      is renamed to ``layers.N+1.downsample`` and ``head.`` to ``head.fc.``.

    Args:
        state_dict: State dictionary from checkpoint, optionally nested under a
            ``'model'`` and/or ``'state_dict'`` key.
        model: Model instance the filtered weights are meant for.

    Returns:
        Filtered state dictionary.
    """
    import re

    # Unwrap before inspecting keys: the legacy-format check below must look at
    # the actual parameter names, not at the wrapper dict, otherwise a wrapped
    # new-format checkpoint would be renamed a second time.
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    old_weights = 'head.fc.weight' not in state_dict

    skip_markers = ('relative_position_index', 'attn_mask')
    # Includes the leading dot so that stripping it leaves the owning module path.
    bias_table_suffix = '.relative_position_bias_table'
    # Dots are escaped so that only literal `layers.<idx>.downsample` matches.
    downsample_pattern = re.compile(r'layers\.(\d+)\.downsample')

    out_dict = {}
    for k, v in state_dict.items():
        if any(n in k for n in skip_markers):
            continue  # skip buffers that should not be persistent

        if 'patch_embed.proj.weight' in k:
            _, _, H, W = model.patch_embed.proj.weight.shape
            if v.shape[-2] != H or v.shape[-1] != W:
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation='bicubic',
                    antialias=True,
                    verbose=True,
                )

        if k.endswith('relative_position_bias_table'):
            # Look up the module that owns this table in the target model; the
            # lookup uses the checkpoint's original (pre-rename) key.
            m = model.get_submodule(k[:-len(bias_table_suffix)])
            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
                v = resize_rel_pos_bias_table(
                    v,
                    new_window_size=m.window_size,
                    new_bias_shape=m.relative_position_bias_table.shape,
                )

        if old_weights:
            # Legacy layout: shift the downsample index up by one stage and
            # move the classifier weights under `head.fc.`.
            k = downsample_pattern.sub(lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
            k = k.replace('head.', 'head.fc.')

        out_dict[k] = v
    return out_dict