in models/language_model.py [0:0]
def __init__(self, cfg):
    """Build a decoder-only LM: token embedding, rotary embedding, a stack of
    transformer blocks, a final RMSNorm, and a vocab-projection head.

    Args:
        cfg: config object read for `lm_vocab_size`, `lm_hidden_dim`,
            `lm_n_blocks`, `lm_use_tokens`, and `lm_tie_weights`.
    """
    super().__init__()
    self.cfg = cfg
    # Copied onto the instance for quick access; presumably lm_use_tokens
    # selects between token-id and pre-embedded inputs in forward() — confirm there.
    self.lm_use_tokens = cfg.lm_use_tokens
    self.lm_tie_weights = cfg.lm_tie_weights

    self.token_embedding = nn.Embedding(cfg.lm_vocab_size, cfg.lm_hidden_dim)
    self.rotary_embd = RotaryEmbedding(cfg)
    self.blocks = nn.ModuleList(
        LanguageModelBlock(cfg) for _ in range(cfg.lm_n_blocks)
    )
    self.norm = RMSNorm(cfg)  # final normalization before the output head
    self.head = nn.Linear(cfg.lm_hidden_dim, cfg.lm_vocab_size, bias=False)
    if self.lm_tie_weights:
        # Weight tying: head and embedding share one Parameter, so the
        # subsequent _init_weights pass initializes a single shared tensor.
        self.head.weight = self.token_embedding.weight
    self.apply(self._init_weights)