in timm/models/vision_transformer.py [0:0]
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None:
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
if load_bfloat16:
import jax.numpy as jnp
import ml_dtypes
def _n2p(_w, t=True, idx=None):
if idx is not None:
_w = _w[idx]
if load_bfloat16:
_w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32)
_w = np.array(_w)
if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1:
_w = _w.flatten()
if t:
if _w.ndim == 4:
_w = _w.transpose([3, 2, 0, 1])
elif _w.ndim == 3:
_w = _w.transpose([2, 0, 1])
elif _w.ndim == 2:
_w = _w.transpose([1, 0])
_w = torch.from_numpy(_w)
return _w
if load_bfloat16:
w = jnp.load(checkpoint_path)
else:
w = np.load(checkpoint_path)
interpolation = 'bilinear'
antialias = False
big_vision = False
if not prefix:
if 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
elif 'params/embedding/kernel' in w:
prefix = 'params/'
big_vision = True
elif 'params/img/embedding/kernel' in w:
prefix = 'params/img/'
big_vision = True
if hasattr(model.patch_embed, 'backbone'):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
embed_conv_w = resample_patch_embed(
embed_conv_w,
model.patch_embed.proj.weight.shape[-2:],
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
if model.cls_token is not None:
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
if big_vision:
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
else:
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w,
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if (isinstance(model.head, nn.Linear) and
f'{prefix}head/bias' in w and
model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
# NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
if model.attn_pool is not None:
block_prefix = f'{prefix}MAPHead_0/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
model.attn_pool.kv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
model.attn_pool.kv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
for r in range(2):
getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
for i, block in enumerate(model.blocks.children()):
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
block_prefix = f'{prefix}Transformer/encoderblock/'
idx = i
else:
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
idx = None
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))