in models/vision_transformer.py
def __init__(self, cfg):
    super().__init__()
    # Attention geometry from the config
    self.n_heads = cfg.vit_n_heads
    self.embd_dim = cfg.vit_hidden_dim
    assert self.embd_dim % self.n_heads == 0, "embd_dim must be divisible by num_heads"
    self.head_dim = self.embd_dim // self.n_heads
    self.dropout = cfg.vit_dropout
    # Combined Q, K, V projection for all heads, plus output projection
    self.qkv_proj = nn.Linear(self.embd_dim, 3 * self.embd_dim, bias=True)
    self.out_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=True)
    # Dropout layers
    self.attn_dropout = nn.Dropout(self.dropout)
    self.resid_dropout = nn.Dropout(self.dropout)
    # Use fused scaled dot-product attention if this PyTorch version provides it
    self.sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
    if not self.sdpa:
        print("Warning: scaled dot product attention not available. Using standard attention in ViT.")