in ViT4MNIST/vit_pytorch.py [0:0]
def __init__(self, dim, heads=8):
    """Store the head count and scale factor and build two linear layers.

    Args:
        dim: Feature width. It is the input size of both linear layers
            and the output size of ``to_out``.
        heads: Value kept on ``self.heads``. Defaults to 8.
    """
    super().__init__()

    self.heads = heads
    # The scale factor is computed from the full ``dim`` (dim to the
    # power -0.5), not from any per-head size.
    self.scale = dim ** -0.5

    # A single bias-free layer whose output is three times ``dim`` wide.
    # NOTE(review): presumably split into query/key/value by the forward
    # pass -- confirm against the rest of the class.
    qkv_width = dim * 3
    self.to_qkv = nn.Linear(in_features=dim, out_features=qkv_width, bias=False)

    # Registered after ``to_qkv`` so parameter order is unchanged.
    self.to_out = nn.Linear(in_features=dim, out_features=dim)