# models/vision_transformer.py
def forward(self, x):
    x = self.conv(x)       # patchify: (B, C, H, W) -> (B, hidden_dim, H/P, W/P)
    x = x.flatten(2)       # merge the spatial dims into one: (B, hidden_dim, num_patches)
    x = x.transpose(1, 2)  # reorder to (B, num_patches, hidden_dim)
    # Prepend the learnable CLS token (as in the original ViT paper), then add position embeddings
    if self.cls_flag:
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # broadcast the CLS token across the batch
        x = torch.cat((cls_token, x), dim=1)
    x = x + self.position_embedding  # position_embedding must match the final sequence length
    return x
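
# For context, a minimal sketch of the __init__ this forward pass assumes. The attribute
# names (self.conv, self.cls_flag, self.cls_token, self.position_embedding) come from the
# method above; the class name, constructor arguments, and default values are illustrative
# assumptions, not the repository's actual signature.
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, hidden_dim=768, cls_flag=True):
        super().__init__()
        self.cls_flag = cls_flag
        num_patches = (img_size // patch_size) ** 2
        # Non-overlapping patch projection: kernel size and stride both equal to the patch size
        self.conv = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
        if cls_flag:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            num_patches += 1  # one extra position for the CLS token
        self.position_embedding = nn.Parameter(torch.zeros(1, num_patches, hidden_dim))

    # forward(self, x) as defined above would be attached here.

# Usage sketch (assuming the defaults above and the forward method attached):
# embed = PatchEmbedding()
# tokens = embed(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 197, 768]) with cls_flag=True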