in patchconvnet_models.py
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
             num_heads=1, qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
             block_layers=Layer_scale_init_Block,
             block_layers_token=Layer_scale_init_Block_only_token,
             Patch_layer=ConvStem, act_layer=nn.GELU,
             Attention_block=Conv_blocks_se,
             dpr_constant=True, init_scale=1e-4,
             Attention_block_token_only=Learned_Aggregation_Layer,
             Mlp_block_token_only=Mlp,
             depth_token_only=1,
             mlp_ratio_clstk=3.0,
             multiclass=False):
    super().__init__()

    self.multiclass = multiclass
    self.patch_size = patch_size
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    # Convolutional stem that maps the input image to a sequence of patch embeddings.
    self.patch_embed = Patch_layer(
        img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches
    # Single class token for standard classification; one token per class in the
    # multiclass (multi-label) setting.
    if not self.multiclass:
        self.cls_token = nn.Parameter(torch.zeros(1, 1, int(embed_dim)))
    else:
        self.cls_token = nn.Parameter(torch.zeros(1, num_classes, int(embed_dim)))
    # Stochastic-depth (drop-path) schedule: either a linear ramp from 0 up to
    # drop_path_rate across the trunk, or a constant rate for every block.
    if not dpr_constant:
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    else:
        dpr = [drop_path_rate for _ in range(depth)]
    # Main trunk: `depth` blocks with LayerScale initialization.
    self.blocks = nn.ModuleList([
        block_layers(
            dim=embed_dim, drop_path=dpr[i], norm_layer=norm_layer,
            act_layer=act_layer, Attention_block=Attention_block, init_values=init_scale)
        for i in range(depth)])
    # Aggregation stage: attention blocks that only update the class token(s),
    # leaving the patch sequence unchanged.
    self.blocks_token_only = nn.ModuleList([
        block_layers_token(
            dim=int(embed_dim), num_heads=num_heads, mlp_ratio=mlp_ratio_clstk,
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0.0, norm_layer=norm_layer,
            act_layer=act_layer, Attention_block=Attention_block_token_only,
            Mlp_block=Mlp_block_token_only, init_values=init_scale)
        for _ in range(depth_token_only)])
    self.norm = norm_layer(int(embed_dim))
    self.total_len = depth_token_only + depth
    self.feature_info = [dict(num_chs=int(embed_dim), reduction=0, module='head')]
    # Classification head: a single linear layer, or one binary (1-logit) head
    # per class in the multiclass setting.
    if not self.multiclass:
        self.head = nn.Linear(int(embed_dim), num_classes) if num_classes > 0 else nn.Identity()
    else:
        self.head = nn.ModuleList([nn.Linear(int(embed_dim), 1) for _ in range(num_classes)])

    self.rescale = 0.02
    trunc_normal_(self.cls_token, std=self.rescale)
    self.apply(self._init_weights)
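
For reference, a minimal usage sketch. It assumes this `__init__` belongs to a class named `PatchConvnet` exported from `patchconvnet_models.py`, and that the class defines the usual `forward` mapping images to logits; neither the class definition nor `forward` is shown in this excerpt.

# Hypothetical usage sketch -- not part of the original file.
import torch
from patchconvnet_models import PatchConvnet  # assumed class name

model = PatchConvnet(img_size=224, patch_size=16, embed_dim=768,
                     depth=12, num_classes=1000)
model.eval()

x = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
with torch.no_grad():
    logits = model(x)            # expected shape: (1, 1000) for the default head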