in vissl/models/trunks/vision_transformer.py [0:0]
def __init__(self, model_config: AttrDict, model_name: str):
super().__init__()
assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
trunk_config = copy.deepcopy(model_config.TRUNK.VISION_TRANSFORMERS)
logging.info("Building model: Vision Transformer from yaml config")
# Hacky workaround: lowercase all config keys so the YAML fields can be read as attributes below
trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})
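# Core ViT hyper-parameters read from the (now lowercase) trunk config;
# in_chans and mlp_ratio are fixed here rather than configurable.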
img_size = trunk_config.image_size
patch_size = trunk_config.patch_size
in_chans = 3
embed_dim = trunk_config.hidden_dim
depth = trunk_config.num_layers
num_heads = trunk_config.num_heads
mlp_ratio = 4.0
qkv_bias = trunk_config.qkv_bias
qk_scale = trunk_config.qk_scale
drop_rate = trunk_config.dropout_rate
attn_drop_rate = trunk_config.attention_dropout_rate
drop_path_rate = trunk_config.drop_path_rate
hybrid_backbone_string = None
# TODO: Implement hybrid backbones
if "HYBRID" in trunk_config.keys():
hybrid_backbone_string = trunk_config.HYBRID
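# LayerNorm with eps=1e-6 is used inside the transformer blocks and for the final norm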
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
# TODO: Enable hybrid backbones
if hybrid_backbone_string:
self.patch_embed = globals()[hybrid_backbone_string](
out_dim=embed_dim, img_size=img_size
)
# if hybrid_backbone is not None:
# self.patch_embed = HybridEmbed(
# hybrid_backbone,
# img_size=img_size,
# in_chans=in_chans,
# embed_dim=embed_dim,
# )
else:
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
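# The patch embedding yields one token per image patch; for square inputs this is
# typically (img_size // patch_size) ** 2 tokens.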
num_patches = self.patch_embed.num_patches
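# Learnable class token plus positional embeddings; the extra "+ 1" position is for
# the class token prepended to the patch tokens.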
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
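# Stack of `depth` transformer encoder blocks (multi-head self-attention + MLP),
# each with its own stochastic-depth rate taken from `dpr`.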
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)
# NOTE as per official impl, we could have a pre-logits
# representation dense layer + tanh here
# self.repr = nn.Linear(embed_dim, representation_size)
# self.repr_act = nn.Tanh()
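# Initialize class token and positional embeddings with a truncated normal,
# then apply _init_weights to the remaining modules.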
trunc_normal_(self.pos_embed, std=0.02)
trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
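# Usage sketch (illustrative only, not part of VISSL): a minimal config carrying just
# the fields this constructor reads. The enclosing class name `VisionTransformer`, the
# uppercase key spellings, and the ViT-B/16-style values are assumptions, not taken
# from this file; AttrDict is the same class already imported by this module.
#
#   example_config = AttrDict({
#       "INPUT_TYPE": "rgb",
#       "TRUNK": AttrDict({
#           "VISION_TRANSFORMERS": AttrDict({
#               "IMAGE_SIZE": 224,
#               "PATCH_SIZE": 16,
#               "HIDDEN_DIM": 768,
#               "NUM_LAYERS": 12,
#               "NUM_HEADS": 12,
#               "QKV_BIAS": True,
#               "QK_SCALE": None,
#               "DROPOUT_RATE": 0.0,
#               "ATTENTION_DROPOUT_RATE": 0.0,
#               "DROP_PATH_RATE": 0.1,
#           })
#       }),
#   })
#   trunk = VisionTransformer(example_config, "vision_transformer")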