in timm/models/levit.py [0:0]
def __init__(
        self,
        img_size=224,
        in_chans=3,
        num_classes=1000,
        embed_dim=(192,),
        key_dim=64,
        depth=(12,),
        num_heads=(3,),
        attn_ratio=2.,
        mlp_ratio=2.,
        stem_backbone=None,
        stem_stride=None,
        stem_type='s16',
        down_op='subsample',
        act_layer='hard_swish',
        attn_act_layer=None,
        use_conv=False,
        global_pool='avg',
        drop_rate=0.,
        drop_path_rate=0.):
    """Build the LeViT model: a stem, a sequence of ``LevitStage`` stages and a classifier head.

    Args:
        img_size: Input image size (int or 2-tuple). Divided by the stem stride to get the
            token-grid resolution fed to the first stage.
        in_chans: Number of input image channels (used by the built-in stems).
        num_classes: Number of classifier outputs. If not > 0, the head is ``nn.Identity``.
        embed_dim: Per-stage embedding widths. Its length defines the number of stages,
            and its last entry is ``num_features``.
        key_dim: Attention key dimension, shared by all stages.
        depth: Per-stage block counts. Must have the same length as ``embed_dim``.
        num_heads: Per-stage head counts, expanded to ``num_stages`` entries via ``to_ntuple``.
        attn_ratio: Per-stage attention ratio, expanded via ``to_ntuple``.
        mlp_ratio: Per-stage MLP ratio, expanded via ``to_ntuple``.
        stem_backbone: Optional custom stem module. When given, ``stem_stride`` is required.
        stem_stride: Total stride of ``stem_backbone``. Must be set and >= 2 when
            ``stem_backbone`` is given; ignored otherwise.
        stem_type: Built-in stem used when no backbone is given: ``'s16'`` or ``'s8'``.
        down_op: Downsample op passed to every stage after the first (the first stage gets ``''``).
        act_layer: Activation (name or layer), resolved with ``get_act_layer``.
        attn_act_layer: Activation for attention blocks. Defaults to ``act_layer`` when None.
        use_conv: Passed through to every stage and stored on ``self.use_conv``.
        global_pool: Pooling mode stored on ``self.global_pool`` (not used in this constructor).
        drop_rate: Dropout rate passed to the ``NormLinear`` classifier head.
        drop_path_rate: Drop-path rate passed unchanged to every stage.
    """
    super().__init__()
    act_layer = get_act_layer(act_layer)
    attn_act_layer = get_act_layer(attn_act_layer or act_layer)
    self.use_conv = use_conv
    self.num_classes = num_classes
    self.global_pool = global_pool
    self.num_features = self.head_hidden_size = embed_dim[-1]
    self.embed_dim = embed_dim
    self.drop_rate = drop_rate
    self.grad_checkpointing = False
    self.feature_info = []

    num_stages = len(embed_dim)
    assert len(depth) == num_stages, \
        f'depth must have one entry per stage ({num_stages}), got {len(depth)}'
    num_heads = to_ntuple(num_stages)(num_heads)
    attn_ratio = to_ntuple(num_stages)(attn_ratio)
    mlp_ratio = to_ntuple(num_stages)(mlp_ratio)

    if stem_backbone is not None:
        # stem_stride defaults to None; check it explicitly so a missing value gives a clear
        # message instead of a NoneType comparison TypeError.
        assert stem_stride is not None and stem_stride >= 2, \
            f'stem_stride >= 2 is required when stem_backbone is provided, got {stem_stride}'
        self.stem = stem_backbone
        stride = stem_stride
    else:
        assert stem_type in ('s16', 's8'), f"stem_type must be 's16' or 's8', got {stem_type!r}"
        if stem_type == 's16':
            self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer)
        else:
            self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer)
        stride = self.stem.stride

    # Token-grid size produced by the stem.
    resolution = tuple(i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride)))

    in_dim = embed_dim[0]
    stages = []
    for i in range(num_stages):
        # The first stage keeps the stem resolution; every later stage downsamples by 2.
        stage_stride = 2 if i > 0 else 1
        stages += [LevitStage(
            in_dim,
            embed_dim[i],
            key_dim,
            depth=depth[i],
            num_heads=num_heads[i],
            attn_ratio=attn_ratio[i],
            mlp_ratio=mlp_ratio[i],
            act_layer=act_layer,
            attn_act_layer=attn_act_layer,
            resolution=resolution,  # resolution entering this stage, before its own downsample
            use_conv=use_conv,
            downsample=down_op if stage_stride == 2 else '',
            drop_path=drop_path_rate
        )]
        # Cumulative reduction relative to the input image, reported in feature_info.
        stride *= stage_stride
        # (r - 1) // s + 1 == ceil(r / s): spatial size after this stage's stride.
        resolution = tuple((r - 1) // stage_stride + 1 for r in resolution)
        self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
        in_dim = embed_dim[i]
    self.stages = nn.Sequential(*stages)

    # Classifier head
    self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate) if num_classes > 0 else nn.Identity()