in ViT4MNIST/vit_pytorch.py [0:0]
def __init__(self, dim, depth, heads, mlp_dim):
    """Assemble the stack of attention / feed-forward layer pairs.

    ``self.layers`` becomes an ``nn.ModuleList`` holding ``depth`` entries.
    Each entry is itself an ``nn.ModuleList`` of two modules, in order:

    1. ``Residual(PreNorm(dim, Attention(dim, heads=heads)))``
    2. ``Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))``

    Args:
        dim: Forwarded to ``PreNorm``, ``Attention`` and ``FeedForward``
            (presumably the token embedding width -- confirm against
            those classes).
        depth: Number of (attention, feed-forward) pairs to create.
        heads: Forwarded to ``Attention`` as its ``heads`` keyword.
        mlp_dim: Forwarded to ``FeedForward`` as its second positional
            argument (presumably the hidden width -- confirm).
    """
    super().__init__()

    def build_pair():
        # Attention is constructed before the feed-forward block so that
        # sub-modules are created in the same order for every layer.
        attention = Residual(PreNorm(dim, Attention(dim, heads=heads)))
        feed_forward = Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
        return nn.ModuleList([attention, feed_forward])

    self.layers = nn.ModuleList([build_pair() for _ in range(depth)])