in easycv/models/backbones/shuffle_transformer.py [0:0]
def __init__(self,
             img_size=224,
             in_chans=3,
             num_classes=1000,
             token_dim=32,
             embed_dim=96,
             mlp_ratio=4.,
             layers=[2, 2, 6, 2],
             num_heads=[3, 6, 12, 24],
             relative_pos_embedding=True,
             shuffle=True,
             window_size=7,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.,
             has_pos_embed=False,
             **kwargs):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    self.has_pos_embed = has_pos_embed
    # Per-stage channel widths: each attention head is 32 channels wide.
    dims = [i * 32 for i in num_heads]

    # Token embedding stem; spatial resolution is reduced 4x in each
    # dimension, giving (img_size // 4) ** 2 tokens.
    self.to_token = PatchEmbedding(
        inter_channel=token_dim, out_channels=embed_dim)
    num_patches = (img_size * img_size) // 16

    if self.has_pos_embed:
        raise NotImplementedError
        # self.pos_embed = nn.Parameter(
        #     data=get_sinusoid_encoding(
        #         n_position=num_patches, d_hid=embed_dim),
        #     requires_grad=False)
        # self.pos_drop = nn.Dropout(p=drop_rate)

    # Stochastic depth decay rule: one drop-path rate per stage,
    # increasing linearly from 0 to drop_path_rate.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 4)]

    self.stage1 = StageModule(
        layers[0],
        embed_dim,
        dims[0],
        num_heads[0],
        window_size=window_size,
        shuffle=shuffle,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop=drop_rate,
        attn_drop=attn_drop_rate,
        drop_path=dpr[0],
        relative_pos_embedding=relative_pos_embedding)
    self.stage2 = StageModule(
        layers[1],
        dims[0],
        dims[1],
        num_heads[1],
        window_size=window_size,
        shuffle=shuffle,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop=drop_rate,
        attn_drop=attn_drop_rate,
        drop_path=dpr[1],
        relative_pos_embedding=relative_pos_embedding)
    self.stage3 = StageModule(
        layers[2],
        dims[1],
        dims[2],
        num_heads[2],
        window_size=window_size,
        shuffle=shuffle,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop=drop_rate,
        attn_drop=attn_drop_rate,
        drop_path=dpr[2],
        relative_pos_embedding=relative_pos_embedding)
    self.stage4 = StageModule(
        layers[3],
        dims[2],
        dims[3],
        num_heads[3],
        window_size=window_size,
        shuffle=shuffle,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop=drop_rate,
        attn_drop=attn_drop_rate,
        drop_path=dpr[3],
        relative_pos_embedding=relative_pos_embedding)

    # Global average pooling over the final feature map, then the classifier head.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    # Classifier head
    self.head = nn.Linear(
        dims[3], num_classes) if num_classes > 0 else nn.Identity()
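
# Usage sketch: a minimal construction of this backbone with the default
# ImageNet-style configuration. Assumptions (not shown in the excerpt above):
# the enclosing class is ShuffleTransformer, as the file path suggests, and
# its forward pass accepts an NCHW image batch and returns class logits.
import torch

model = ShuffleTransformer(
    img_size=224,
    num_classes=1000,
    embed_dim=96,
    layers=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # expected shape: (1, 1000), assuming forward returns logits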