in levit_c.py [0:0]
def model_factory(C, D, X, N, drop_path, weights,
                  num_classes, distillation, pretrained, fuse):
    """Build a ``LeViT`` model from a compact spec, optionally loading weights.

    Args:
        C: Underscore-separated integers (``"a_b_c"``) parsed into the
            ``embed_dim`` list.
        D: Key dimension. Used for every stage (``key_dim=[D] * 3``) and to
            derive the head count of each ``Subsample`` op
            (``embed_dim[i] // D``).
        X: Underscore-separated integers parsed into ``depth``.
        N: Underscore-separated integers parsed into ``num_heads``.
        drop_path: Forwarded to ``LeViT`` as ``drop_path``.
        weights: Checkpoint URL; only fetched when ``pretrained`` is truthy.
        num_classes: Forwarded to ``LeViT`` as ``num_classes``.
        distillation: Forwarded to ``LeViT`` as ``distillation``.
        pretrained: When truthy, download the checkpoint at ``weights`` and
            load its ``'model'`` entry into the network.
        fuse: When truthy, ``utils.replace_batchnorm`` is applied to the
            model before it is returned.

    Returns:
        The constructed (and possibly weight-loaded / fused) model.
    """
    embed_dim = [int(x) for x in C.split('_')]
    num_heads = [int(x) for x in N.split('_')]
    depth = [int(x) for x in X.split('_')]
    act = torch.nn.Hardswish
    model = LeViT(
        patch_size=16,
        embed_dim=embed_dim,
        num_heads=num_heads,
        key_dim=[D] * 3,
        depth=depth,
        attn_ratio=[2, 2, 2],
        mlp_ratio=[2, 2, 2],
        down_ops=[
            # ('Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride)
            ['Subsample', D, embed_dim[0] // D, 4, 2, 2],
            ['Subsample', D, embed_dim[1] // D, 4, 2, 2],
        ],
        attention_activation=act,
        mlp_activation=act,
        hybrid_backbone=b16(embed_dim[0], activation=act),
        num_classes=num_classes,
        drop_path=drop_path,
        distillation=distillation,
    )
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            weights, map_location='cpu')
        ckpt_state = checkpoint['model']
        # Kept under its own name so the ``D`` parameter is not shadowed.
        model_state = model.state_dict()
        # Iterate over a snapshot because entries are rewritten in the loop.
        for name, tensor in list(ckpt_state.items()):
            target = model_state.get(name)
            if target is None or target.shape == tensor.shape:
                # Unknown keys are left for load_state_dict to report.
                continue
            if tensor.dim() < 2:
                # Cannot take two leading slices; load_state_dict will
                # raise a descriptive size-mismatch error instead.
                continue
            # Append two singleton axes after the first two dimensions
            # (presumably linear weights feeding 1x1 convolutions -- confirm
            # against the LeViT definition). Only adopt the result when it
            # actually matches what the model expects.
            expanded = tensor[:, :, None, None]
            if expanded.shape == target.shape:
                ckpt_state[name] = expanded
        model.load_state_dict(ckpt_state)
    if fuse:
        utils.replace_batchnorm(model)
    return model