in vits.py [0:0]
def __init__(self, stop_grad_conv1=False, **kwargs):
    """Build the parent ViT, then apply the MoCo v3 weight initialization.

    Args:
        stop_grad_conv1: If True and ``self.patch_embed`` is a ``PatchEmbed``,
            freeze the patch-projection parameters by setting
            ``requires_grad = False`` on its weight (and bias, when present).
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)
    # Use fixed 2D sin-cos position embedding (helper defined outside this
    # block).
    self.build_2d_sincos_position_embedding()

    # Weight initialization for every nn.Linear in the model.
    for name, m in self.named_modules():
        if not isinstance(m, nn.Linear):
            continue
        if 'qkv' in name:
            # Treat the weights of Q, K, V separately: the fused weight's
            # output dimension is divided by 3, so the Xavier-uniform bound
            # is computed per projection rather than for the whole matrix.
            val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
            nn.init.uniform_(m.weight, -val, val)
        else:
            nn.init.xavier_uniform_(m.weight)
        # Layers created with bias=False (e.g. qkv_bias=False) have
        # bias=None; zeros_(None) would raise AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    nn.init.normal_(self.cls_token, std=1e-6)

    if isinstance(self.patch_embed, PatchEmbed):
        proj = self.patch_embed.proj
        # Xavier-uniform initialization that treats the patch projection as a
        # linear map from (in_chans * patch area) inputs to embed_dim outputs.
        # NOTE(review): assumes proj is a Conv2d whose weight is laid out as
        # (embed_dim, in_chans, kh, kw), so dim 1 is the input channel count.
        # This replaces a hard-coded 3 (RGB); for 3-channel input the bound is
        # identical to before.
        fan_in = proj.weight.shape[1] * reduce(mul, self.patch_embed.patch_size, 1)
        val = math.sqrt(6. / float(fan_in + self.embed_dim))
        nn.init.uniform_(proj.weight, -val, val)
        if proj.bias is not None:
            nn.init.zeros_(proj.bias)

        if stop_grad_conv1:
            # Freeze the patch projection so it is excluded from gradient
            # updates.
            proj.weight.requires_grad = False
            if proj.bias is not None:
                proj.bias.requires_grad = False