in models.py [0:0]
def convit_small(pretrained=False, **kwargs):
    """Build the 9-head ConViT "small" model.

    Args:
        pretrained: When truthy, fetch the published ``convit_small.pth``
            state dict and load it into the freshly built model.
        **kwargs: Forwarded to ``VisionTransformer``. Must contain
            ``embed_dim``; the supplied value is multiplied by the head
            count (9) before being passed on. Presumably it is the
            per-head width -- confirm against the caller.

    Returns:
        The constructed ``VisionTransformer`` with ``default_cfg`` set to
        the result of ``_cfg()``.

    Raises:
        KeyError: If ``embed_dim`` is missing from ``kwargs``.
    """
    head_count = 9
    # Scale the caller-supplied embed_dim by the number of heads.
    kwargs['embed_dim'] *= head_count

    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    net = VisionTransformer(
        num_heads=head_count,
        norm_layer=layer_norm,
        **kwargs,
    )
    net.default_cfg = _cfg()

    if not pretrained:
        return net

    # Tensors are mapped to CPU, and check_hash=True asks torch.hub to
    # verify the downloaded file's hash.
    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/convit/convit_small.pth",
        map_location="cpu",
        check_hash=True,
    )
    net.load_state_dict(state)
    return net