# models/vision_transformer.py
# Module-level imports used by this method.
import math

import torch
import torch.nn.functional as F


def forward(self, x):
    B, T, C = x.size()
    # Single projection produces concatenated q, k, v: [B, T, C] -> [B, T, 3*C],
    # then split back into three [B, T, C] tensors.
    qkv = self.qkv_proj(x)
    q, k, v = qkv.split(C, dim=2)
    # Reshape [B, T, C] -> [B, T, n_heads, head_dim] and transpose -> [B, n_heads, T, head_dim]
    q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
    k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
    v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
    if self.sdpa:
        # Fused scaled dot-product attention (available since PyTorch 2.0).
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,  # ViT attention is bidirectional
        )
    else:
        # Manual attention: softmax(Q @ K^T / sqrt(head_dim)) @ V
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        y = attn @ v  # (B, n_heads, T, T) @ (B, n_heads, T, head_dim) -> (B, n_heads, T, head_dim)
    # Transpose back to [B, T, n_heads, head_dim], then merge the heads -> [B, T, C]
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.out_proj(y)
    y = self.resid_dropout(y)
    return y
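

# --- Sanity-check sketch (illustrative, not part of the original module) ---
# A minimal, self-contained check that the manual softmax(Q @ K^T / sqrt(d)) @ V path
# above matches torch.nn.functional.scaled_dot_product_attention when dropout is
# disabled and attention is bidirectional (is_causal=False). It reuses the
# module-level imports above and assumes PyTorch >= 2.0; the shapes below are
# arbitrary example values, not taken from the model config.
if __name__ == "__main__":
    B, n_heads, T, head_dim = 2, 4, 16, 8
    q = torch.randn(B, n_heads, T, head_dim)
    k = torch.randn(B, n_heads, T, head_dim)
    v = torch.randn(B, n_heads, T, head_dim)

    # Manual path, mirroring the else-branch of forward().
    attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    y_manual = F.softmax(attn, dim=-1) @ v

    # Fused path, mirroring the self.sdpa branch with dropout_p=0.0.
    y_sdpa = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )

    # The two paths should agree up to floating-point tolerance.
    assert torch.allclose(y_manual, y_sdpa, atol=1e-5)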