in models/vision_transformer.py [0:0]
def __init__(self, cfg):
    super().__init__()
    self.img_size = cfg.vit_img_size
    self.patch_size = cfg.vit_patch_size
    self.num_patches = (self.img_size // self.patch_size) ** 2
    self.cls_flag = cfg.vit_cls_flag
    self.embd_dim = cfg.vit_hidden_dim

    # Non-overlapping patch extraction: a conv with kernel_size == stride == patch_size
    # maps each patch of the input image to a single embd_dim-dimensional vector.
    self.conv = nn.Conv2d(
        in_channels=3,
        out_channels=self.embd_dim,
        kernel_size=self.patch_size,
        stride=self.patch_size,
        padding="valid",  # no padding; assumes img_size is divisible by patch_size
    )

    if self.cls_flag:
        # Learnable [CLS] token prepended to the patch sequence, so the position
        # embedding covers num_patches + 1 positions.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embd_dim))
        self.position_embedding = nn.Parameter(torch.rand(1, self.num_patches + 1, self.embd_dim))
    else:
        self.position_embedding = nn.Parameter(torch.rand(1, self.num_patches, self.embd_dim))
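
The forward pass is not shown in this excerpt. Below is a minimal, self-contained sketch of how such a patch embedding is typically consumed: patchify with the conv, flatten to a token sequence, optionally prepend the CLS token, then add the position embedding. The class name PatchEmbeddingSketch, the SimpleNamespace config, and the forward method are illustrative assumptions, not code from models/vision_transformer.py.

# Sketch only: mirrors the __init__ above and adds an assumed forward pass.
from types import SimpleNamespace

import torch
import torch.nn as nn


class PatchEmbeddingSketch(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.patch_size = cfg.vit_patch_size
        self.num_patches = (cfg.vit_img_size // cfg.vit_patch_size) ** 2
        self.cls_flag = cfg.vit_cls_flag
        self.embd_dim = cfg.vit_hidden_dim
        self.conv = nn.Conv2d(3, self.embd_dim, kernel_size=self.patch_size,
                              stride=self.patch_size, padding="valid")
        num_positions = self.num_patches + 1 if self.cls_flag else self.num_patches
        if self.cls_flag:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embd_dim))
        self.position_embedding = nn.Parameter(torch.rand(1, num_positions, self.embd_dim))

    def forward(self, x):
        # x: (B, 3, img_size, img_size)
        x = self.conv(x)                  # (B, embd_dim, img_size/P, img_size/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embd_dim)
        if self.cls_flag:
            cls = self.cls_token.expand(x.size(0), -1, -1)  # (B, 1, embd_dim)
            x = torch.cat((cls, x), dim=1)                  # prepend CLS token
        return x + self.position_embedding                  # broadcast over batch


# Usage example with an assumed config.
cfg = SimpleNamespace(vit_img_size=224, vit_patch_size=16,
                      vit_cls_flag=True, vit_hidden_dim=768)
tokens = PatchEmbeddingSketch(cfg)(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 197, 768]) -> 196 patches + 1 CLS token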