in models/language_model.py
def __init__(self, cfg):
    super().__init__()

    self.n_heads = cfg.lm_n_heads          # number of query heads
    self.n_kv_heads = cfg.lm_n_kv_heads    # number of key/value heads (grouped-query attention)
    self.embd_dim = cfg.lm_hidden_dim
    self.dropout = cfg.lm_dropout

    assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    assert self.embd_dim % self.n_heads == 0, "embd_dim must be divisible by n_heads"

    self.n_kv_groups = self.n_heads // self.n_kv_heads  # query heads sharing each key/value head
    self.head_dim = self.embd_dim // self.n_heads

    # Queries keep the full embedding width; keys and values are projected
    # down to n_kv_heads heads only, which shrinks the KV projections and cache.
    self.q_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)
    self.k_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
    self.v_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
    self.out_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)

    self.attn_dropout = nn.Dropout(self.dropout)
    self.resid_dropout = nn.Dropout(self.dropout)

    # Use PyTorch's fused scaled dot-product attention if available,
    # otherwise fall back to a manual attention implementation.
    self.sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
    if not self.sdpa:
        print("Warning: scaled dot product attention not available, using standard attention in LM.")
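The projections above set up grouped-query attention: queries use all n_heads heads, while keys and values are produced for only n_kv_heads heads, each shared by n_kv_groups query heads when attention is computed. Below is a minimal standalone sketch of that grouping; the example sizes, tensor names, and the repeat_interleave expansion are illustrative assumptions, not this module's actual forward pass.

import torch

# Illustrative sizes mirroring the config fields above (values are assumptions)
n_heads, n_kv_heads, embd_dim, seq_len = 8, 2, 512, 16
head_dim = embd_dim // n_heads           # 64
n_kv_groups = n_heads // n_kv_heads      # each KV head serves 4 query heads

q = torch.randn(1, seq_len, n_heads, head_dim)     # as if from q_proj, reshaped per head
k = torch.randn(1, seq_len, n_kv_heads, head_dim)  # as if from k_proj (fewer heads)
v = torch.randn(1, seq_len, n_kv_heads, head_dim)  # as if from v_proj (fewer heads)

# Expand KV heads so every query head has a matching key/value head
k = k.repeat_interleave(n_kv_groups, dim=2)  # (1, seq_len, n_heads, head_dim)
v = v.repeat_interleave(n_kv_groups, dim=2)

out = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
)
print(out.shape)  # torch.Size([1, 8, 16, 64])

Sharing key/value heads this way is the point of grouped-query attention: the k_proj/v_proj weight matrices and the KV cache scale with n_kv_heads rather than n_heads, while query capacity is preserved.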