in xcit.py [0:0]
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
cls_attn_layers=2, use_pos=True, patch_proj='linear', eta=None, tokens_norm=False):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (float): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer (nn.Module): normalization layer
cls_attn_layers (int): number of class-attention layers
use_pos (bool): whether to use positional encoding
patch_proj (str): patch projection type (currently unused; the convolutional patch embedding is always used)
eta (float): LayerScale initialization value
tokens_norm (bool): whether to normalize all tokens or only the cls_token in the class-attention blocks
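
Example (illustrative sketch; assumes the enclosing class is named XCiT and
keeps the default 1000-class head):
# XCiT is assumed to be this model class; eta sets the LayerScale init value
model = XCiT(img_size=224, patch_size=16, embed_dim=384, depth=12,
num_heads=8, eta=1.0, tokens_norm=True)
logits = model(torch.randn(2, 3, 224, 224))  # expected shape: (2, 1000)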
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
# Forward in_chans so non-RGB inputs are handled (ConvPatchEmbed defaults to 3 channels)
self.patch_embed = ConvPatchEmbed(img_size=img_size, patch_size=patch_size,
in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [drop_path_rate] * depth  # uniform stochastic depth rate for every block
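# Main trunk: XCA (cross-covariance attention) blocks operating on the patch tokens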
self.blocks = nn.ModuleList([
XCABlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
norm_layer=norm_layer, num_tokens=num_patches, eta=eta)
for i in range(depth)])
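# Class-attention blocks (CaiT-style): the cls_token attends to the patch tokens
# to aggregate global information before classification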
self.cls_attn_blocks = nn.ModuleList([
ClassAttentionBlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer,
eta=eta, tokens_norm=tokens_norm)
for i in range(cls_attn_layers)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)  # Fourier positional encoding (applied when use_pos is True)
self.use_pos = use_pos
# Initialize weights
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)