in detection/backbone/xcit.py [0:0]
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
cls_attn_layers=2, use_pos=True, eta=None, tokens_norm=False,
out_indices=[3, 5, 7, 11]):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer
cls_attn_layers: (int) Depth of Class attention layers
use_pos: (bool) whether to use positional encoding
eta: (float) layerscale initialization value
tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA
out_indices: (list) Indices of layers from which FPN features are extracted
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim,
patch_size=patch_size)
num_patches = self.patch_embed.num_patches
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.ModuleList([
XCABlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
norm_layer=norm_layer, num_tokens=num_patches, eta=eta)
for i in range(depth)])
self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
self.use_pos = use_pos
self.out_indices = out_indices
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)