in models/vision_transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    assert x.ndim == 4, f"Expected a 4D (n, c, h, w) input, got {x.ndim} dimensions"
    n, c, h, w = x.shape
    p = self.patch_size
    assert h == w == self.image_size, f"Expected a {self.image_size}x{self.image_size} image, got {h}x{w}"
    # number of patches along each spatial dimension
    n_h = h // p
    n_w = w // p
    # patch embedding: conv_proj projects each p x p patch to a hidden_dim vector
    # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
    x = self.conv_proj(x)
    # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
    x = x.reshape(n, self.hidden_dim, n_h * n_w)
    # (n, hidden_dim, (n_h * n_w)) -> ((n_h * n_w), n, hidden_dim)
    # the self-attention layers expect inputs in the format (S, N, E), where S is
    # the source sequence length, N is the batch size, and E is the embedding
    # dimension
    x = x.permute(2, 0, 1)
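    # shape trace with illustrative values (image_size=224, patch_size=16,
    # hidden_dim=768, batch of 8; these values are assumptions, not fixed here):
    # (8, 3, 224, 224) -> (8, 768, 14, 14) -> (8, 768, 196) -> (196, 8, 768)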
    if self.classifier == "token":
        # expand the class token to the full batch
        batch_class_token = self.class_token.expand(-1, n, -1)
        x = torch.cat([batch_class_token, x], dim=0)
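        # the sequence length is now n_h * n_w + 1, with the class token at index 0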
    x = self.encoder(x)
    if self.classifier == "token":
        # just return the output for the class token (sequence position 0)
        x = x[0, :, :]
    else:
        # average-pool over the sequence (patch) dimension
        x = x.mean(dim=0)
    x = self.trunk_output(x)
    if self.head is not None:
        x = self.head(x)
    return x
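
# A minimal usage sketch (not part of the original file). It assumes the enclosing
# VisionTransformer-style module is constructed with image_size=224 and
# patch_size=16; the real constructor arguments may differ.
#
#     model = VisionTransformer(image_size=224, patch_size=16, ...)
#     images = torch.randn(8, 3, 224, 224)  # (n, c, h, w) with h == w == image_size
#     out = model(images)                    # class logits if self.head is set,
#                                            # otherwise the pooled trunk features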