in vissl/models/trunks/convit.py [0:0]
def __init__(self, model_config, model_name):
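    """
    Build the ConViT trunk from the YAML config: a patch embedding followed
    by `num_layers` attention blocks, the first `n_gpsa_layers` of which use
    gated positional self-attention (GPSA) and the rest plain self-attention.
    """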
    super().__init__()
    trunk_config = copy.deepcopy(model_config.TRUNK.CONVIT)
    trunk_config.update(model_config.TRUNK.VISION_TRANSFORMERS)
    logging.info("Building model: ConViT from yaml config")
    # Workaround: lower-case the keys so the upper-case YAML fields
    # (e.g. IMAGE_SIZE) can be read as attributes (e.g. .image_size) below.
    trunk_config = AttrDict({k.lower(): v for k, v in trunk_config.items()})
    image_size = trunk_config.image_size
    patch_size = trunk_config.patch_size
    classifier = trunk_config.classifier
    assert image_size % patch_size == 0, "Input shape indivisible by patch size"
    assert classifier in ["token", "gap"], "Unexpected classifier mode"
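    # "token" classifies from the class token's output; "gap" applies global
    # average pooling over the patch tokens (standard ViT conventions).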
    n_gpsa_layers = trunk_config.n_gpsa_layers
    class_token_in_local_layers = trunk_config.class_token_in_local_layers
    mlp_dim = trunk_config.mlp_dim
    embed_dim = trunk_config.hidden_dim
    locality_dim = trunk_config.locality_dim
    attention_dropout_rate = trunk_config.attention_dropout_rate
    dropout_rate = trunk_config.dropout_rate
    drop_path_rate = trunk_config.drop_path_rate
    num_layers = trunk_config.num_layers
    locality_strength = trunk_config.locality_strength
    num_heads = trunk_config.num_heads
    qkv_bias = trunk_config.qkv_bias
    qk_scale = trunk_config.qk_scale
    use_local_init = trunk_config.use_local_init
    hybrid_backbone = None
    if "hybrid" in trunk_config:
        hybrid_backbone = trunk_config.hybrid
    in_chans = 3
    # TODO: Make this configurable
    norm_layer = nn.LayerNorm
    self.classifier = classifier
    self.n_gpsa_layers = n_gpsa_layers
    self.class_token_in_local_layers = class_token_in_local_layers
    # For consistency with other models
    self.num_features = self.embed_dim = self.hidden_dim = embed_dim
    self.locality_dim = locality_dim
    # Hybrid backbones not tested
    if hybrid_backbone is not None:
        # HybridEmbed derives tokens from a CNN backbone's feature map.
        self.patch_embed = HybridEmbed(
            hybrid_backbone,
            img_size=image_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
    else:
        # PatchEmbed splits the image into non-overlapping patch_size x
        # patch_size patches and linearly projects each one to embed_dim.
        self.patch_embed = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
    seq_length = (image_size // patch_size) ** 2
    self.seq_length = seq_length
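    # For example, with the typical settings image_size=224 and patch_size=16
    # (illustrative values, not fixed by this file), this yields
    # (224 // 16) ** 2 = 196 patch tokens.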
    self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # The positional embedding covers the patch tokens only; the class token
    # receives no positional embedding.
    self.pos_embedding = nn.Parameter(torch.zeros(1, seq_length, embed_dim))
    self.pos_drop = nn.Dropout(p=dropout_rate)
    # Account for the class token when it is also fed through the local
    # (GPSA) layers.
    if class_token_in_local_layers:
        seq_length += 1
    # stochastic depth decay rule
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
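    # E.g. num_layers=12 with drop_path_rate=0.1 (illustrative values) gives
    # dpr = [0.0, 0.0091, ..., 0.1]: the drop probability increases linearly
    # with depth, so later blocks are skipped more often during training.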
    layers = []
    for i in range(num_layers):
        if i < self.n_gpsa_layers:
            # Use the configured locality strength if positive; otherwise
            # fall back to a 1 / (layer_index + 1) schedule, so deeper GPSA
            # layers are initialized progressively less local.
            if locality_strength > 0:
                layer_locality_strength = locality_strength
            else:
                layer_locality_strength = 1 / (i + 1)
            layers.append(
                AttentionBlock(
                    attention_module=GPSA,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    dropout_rate=dropout_rate,
                    attention_dropout_rate=attention_dropout_rate,
                    drop_path_rate=dpr[i],
                    norm_layer=norm_layer,
                    locality_strength=layer_locality_strength,
                    locality_dim=self.locality_dim,
                    use_local_init=use_local_init,
                )
            )
        else:
            # Remaining layers use plain (non-gated) self-attention.
            layers.append(
                AttentionBlock(
                    attention_module=SelfAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    dropout_rate=dropout_rate,
                    attention_dropout_rate=attention_dropout_rate,
                    drop_path_rate=dpr[i],
                    norm_layer=norm_layer,
                )
            )
    self.blocks = nn.ModuleList(layers)
    self.norm = norm_layer(embed_dim)
    # Initialize positional and class-token embeddings with a truncated
    # normal distribution, then apply the module-level weight init.
    trunc_normal_(self.pos_embedding, std=0.02)
    trunc_normal_(self.class_token, std=0.02)
    self.apply(self._init_weights)
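
# Minimal usage sketch (not part of the source file): constructing the trunk
# from a hand-built config. The nested AttrDict layout and the upper-case
# field names mirror what this constructor reads after lower-casing keys;
# the class name `ConViT` and AttrDict behaving as a nested attribute dict
# are assumptions, not confirmed by this excerpt.
if __name__ == "__main__":
    config = AttrDict(
        {
            "TRUNK": {
                "CONVIT": {
                    "N_GPSA_LAYERS": 10,
                    "CLASS_TOKEN_IN_LOCAL_LAYERS": False,
                    "LOCALITY_DIM": 10,
                    "LOCALITY_STRENGTH": 1.0,
                    "USE_LOCAL_INIT": True,
                },
                "VISION_TRANSFORMERS": {
                    "IMAGE_SIZE": 224,
                    "PATCH_SIZE": 16,
                    "NUM_LAYERS": 12,
                    "NUM_HEADS": 12,
                    "HIDDEN_DIM": 768,
                    "MLP_DIM": 3072,
                    "CLASSIFIER": "token",
                    "DROPOUT_RATE": 0.0,
                    "ATTENTION_DROPOUT_RATE": 0.0,
                    "DROP_PATH_RATE": 0.1,
                    "QKV_BIAS": True,
                    "QK_SCALE": None,
                },
            }
        }
    )
    trunk = ConViT(model_config=config, model_name="convit")
    print(trunk.seq_length)  # 196 patch tokens for 224 / 16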