in ViT4MNIST/vit_pytorch.py [0:0]
def forward(self, img, mask=None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)