in cvnets/modules/mobilevit_block.py [0:0]
def __init__(self, opts, in_channels: int, transformer_dim: int, ffn_dim: int,
             n_transformer_blocks: Optional[int] = 2,
             head_dim: Optional[int] = 32, attn_dropout: Optional[float] = 0.1,
             dropout: Optional[float] = 0.1, ffn_dropout: Optional[float] = 0.1,
             patch_h: Optional[int] = 8, patch_w: Optional[int] = 8,
             transformer_norm_layer: Optional[str] = "layer_norm",
             conv_ksize: Optional[int] = 3,
             dilation: Optional[int] = 1, var_ffn: Optional[bool] = False,
             no_fusion: Optional[bool] = False,
             *args, **kwargs):
    # Local representation: a 3x3 conv (with norm + activation) followed by a
    # 1x1 projection into the transformer embedding dimension.
    conv_3x3_in = ConvLayer(
        opts=opts, in_channels=in_channels, out_channels=in_channels,
        kernel_size=conv_ksize, stride=1, use_norm=True, use_act=True, dilation=dilation
    )
    conv_1x1_in = ConvLayer(
        opts=opts, in_channels=in_channels, out_channels=transformer_dim,
        kernel_size=1, stride=1, use_norm=False, use_act=False
    )
    # Project the transformer output back to the convolutional channel width.
    conv_1x1_out = ConvLayer(
        opts=opts, in_channels=transformer_dim, out_channels=in_channels,
        kernel_size=1, stride=1, use_norm=True, use_act=True
    )
    # Optional fusion conv: merges the block input with the projected global
    # features, concatenated channel-wise (hence 2 * in_channels).
    conv_3x3_out = None
    if not no_fusion:
        conv_3x3_out = ConvLayer(
            opts=opts, in_channels=2 * in_channels, out_channels=in_channels,
            kernel_size=conv_ksize, stride=1, use_norm=True, use_act=True
        )
    super(MobileViTBlock, self).__init__()
    self.local_rep = nn.Sequential()
    self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
    self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

    assert transformer_dim % head_dim == 0, (
        "transformer_dim must be divisible by head_dim"
    )
    num_heads = transformer_dim // head_dim
    ffn_dims = [ffn_dim] * n_transformer_blocks
    # Global representation: a stack of transformer encoder blocks followed
    # by a final normalization layer.
    global_rep = [
        TransformerEncoder(
            opts=opts, embed_dim=transformer_dim, ffn_latent_dim=ffn_dims[block_idx],
            num_heads=num_heads, attn_dropout=attn_dropout, dropout=dropout,
            ffn_dropout=ffn_dropout, transformer_norm_layer=transformer_norm_layer
        )
        for block_idx in range(n_transformer_blocks)
    ]
    global_rep.append(
        get_normalization_layer(
            opts=opts, norm_type=transformer_norm_layer, num_features=transformer_dim
        )
    )
    self.global_rep = nn.Sequential(*global_rep)
    self.conv_proj = conv_1x1_out
    self.fusion = conv_3x3_out

    self.patch_h = patch_h
    self.patch_w = patch_w
    self.patch_area = self.patch_w * self.patch_h

    # Bookkeeping attributes.
    self.cnn_in_dim = in_channels
    self.cnn_out_dim = transformer_dim
    self.n_heads = num_heads
    self.ffn_dim = ffn_dim
    self.dropout = dropout
    self.attn_dropout = attn_dropout
    self.ffn_dropout = ffn_dropout
    self.dilation = dilation
    # ffn_dims is uniform here, so the max and min FFN widths coincide.
    self.ffn_max_dim = ffn_dims[0]
    self.ffn_min_dim = ffn_dims[-1]
    self.var_ffn = var_ffn
    self.n_blocks = n_transformer_blocks
    self.conv_ksize = conv_ksize
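
For context, a minimal instantiation sketch, not taken from the source: it assumes `MobileViTBlock` is imported from this module and that `opts` comes from ml-cvnets' own argument parser (whether a bare `argparse.Namespace` suffices depends on the defaults inside ConvLayer and TransformerEncoder). The dimension values are illustrative only; they mimic a mid-level MobileViT stage.

# Hypothetical usage sketch; all values are illustrative assumptions.
import argparse

opts = argparse.Namespace()  # in practice, produced by cvnets' option parser

block = MobileViTBlock(
    opts=opts,
    in_channels=96,          # channel width of the incoming feature map
    transformer_dim=144,     # embedding dim; must be divisible by head_dim
    ffn_dim=288,
    n_transformer_blocks=2,
    head_dim=36,             # 144 // 36 = 4 attention heads
    patch_h=2, patch_w=2,    # spatial patch size used for unfolding
)

Note that `transformer_dim=144` with the default `head_dim=32` would trip the divisibility assert, which is why the sketch passes `head_dim=36` explicitly.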