in grok/transformer.py [0:0]
def __init__(self, d_model: int, heads: int, weight_noise: float = 0.0) -> None:
    super().__init__()
    # Each head attends over a d_key-dimensional slice; d_model should be
    # divisible by heads so the concatenated head outputs sum back to d_model.
    d_key = d_model // heads
    attn_heads = [
        AttentionHead(d_model, d_key, weight_noise=weight_noise)
        for _ in range(heads)
    ]
    self.attn_heads = nn.ModuleList(attn_heads)
    # Output projection mixing the concatenated head outputs back to d_model.
    self.Wo = Linear(d_model, d_model, bias=False, weight_noise=weight_noise)
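
For context, a minimal usage sketch follows. It assumes this __init__ belongs to a multi-head attention nn.Module (the class name MultiHeadAttention is a placeholder, and AttentionHead and the noise-injecting Linear are repo-local classes not shown here), and that the module's forward concatenates the per-head outputs before applying Wo:

# Hypothetical usage sketch; MultiHeadAttention is an assumed name for
# the enclosing class, not confirmed by the excerpt above.
import torch

d_model, heads = 128, 4
mha = MultiHeadAttention(d_model=d_model, heads=heads, weight_noise=0.0)
x = torch.randn(2, 10, d_model)  # (batch, sequence, d_model)
out = mha(x)                     # expected shape: (2, 10, d_model)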