in picotron/model.py
def __init__(self, config, layer_idx):
    super().__init__()
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.num_key_values = config.num_key_value_heads
    self.head_dim = self.hidden_size // self.num_heads
    # Tensor parallelism shards attention heads across ranks, so both head
    # counts must divide evenly by the TP world size.
    assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size"
    assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size"
    # Per-rank head counts under tensor parallelism.
    self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size
    self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size
    # Projections are declared at their full (unsharded) sizes; K/V are narrower
    # than Q when num_key_value_heads < num_attention_heads (grouped-query attention).
    self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
    self.k_proj = nn.Linear(config.hidden_size, self.num_key_values * self.head_dim, bias=False)
    self.v_proj = nn.Linear(config.hidden_size, self.num_key_values * self.head_dim, bias=False)
    self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
    self.layer_idx = layer_idx
    self.reset_parameters()
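
# --- Illustration, not part of picotron: a minimal sketch of the per-rank head
# arithmetic in the constructor above. The config values (Llama-3-8B-like) and
# tp_world_size=4 are assumptions chosen for the example, not taken from the repo.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8     # grouped-query attention: 4 query heads share each KV head
tp_world_size = 4

head_dim = hidden_size // num_attention_heads                # 4096 // 32 = 128
num_local_heads = num_attention_heads // tp_world_size       # 32 // 4 = 8 query heads per rank
num_local_kv_heads = num_key_value_heads // tp_world_size    # 8 // 4 = 2 KV heads per rank

# If the full-size projections are later column-sharded across TP ranks (e.g. by
# a tensor-parallel wrapper), each rank would hold these output widths:
print(num_local_heads * head_dim)      # 1024 of q_proj's 4096 output features
print(num_local_kv_heads * head_dim)   # 256 of k_proj's / v_proj's 1024 output features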