in models/vision_transformer.py [0:0]
def forward(self, x):
    """Encode ``x`` and return layer-normalized features.

    The input is passed through ``self.patch_embedding`` and then
    ``self.dropout``. It is then threaded through every module in
    ``self.blocks`` in order. When ``self.cls_flag`` is truthy, only index 0
    along dimension 1 of the block output is kept before
    ``self.layer_norm`` is applied. Otherwise the whole block output is
    normalized.

    Args:
        x: Model input handed to ``self.patch_embedding``; presumably an
            image batch (TODO confirm expected shape/dtype).

    Returns:
        ``self.layer_norm(hidden[:, 0])`` in cls mode, else
        ``self.layer_norm(hidden)``, where ``hidden`` is the output of the
        last block (or of dropout when ``self.blocks`` is empty).
    """
    hidden = self.dropout(self.patch_embedding(x))
    for encoder_block in self.blocks:
        hidden = encoder_block(hidden)
    # cls mode keeps only position 0 along dim 1 (presumably the class token,
    # assuming a (batch, tokens, dim) layout -- confirm). The non-cls path
    # applies no pooling and returns the full sequence.
    features = hidden[:, 0] if self.cls_flag else hidden
    return self.layer_norm(features)