in picotron/model.py [0:0]
def __init__(self, config, layer_idx):
super().__init__()
# Use the fused Triton RMSNorm kernel when FLASH_ATTEN is enabled (the default); otherwise fall back to the eager LlamaRMSNorm.
RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention = Attention(config, layer_idx=layer_idx)
self.mlp = MLP(config)
self.layer_idx = layer_idx
head_dim = config.hidden_size // config.num_attention_heads
# Precompute the rotary embedding tables once for every position up to max_position_embeddings.
self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim, base=config.rope_theta) # [max_position_embeddings, head_dim]
# With context parallelism the input sequence is sharded across ranks, so keep only the cos/sin rows matching this rank's positions.
self.cos, self.sin = context_parallel.update_rope_for_context_parallel(self.cos, self.sin)
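A minimal sketch of what update_rope_for_context_parallel could do, assuming a plain contiguous split of the sequence across context-parallel ranks (the function name comes from the code above; the cp_rank/cp_size parameters and the splitting scheme are assumptions for illustration, not picotron's actual implementation):

import torch

def update_rope_for_context_parallel_sketch(cos, sin, cp_rank, cp_size):
    # Each context-parallel rank processes only a slice of the sequence,
    # so it only needs the matching rows of the precomputed cos/sin tables.
    seq_len = cos.size(0)  # cos/sin: [max_position_embeddings, head_dim]
    assert seq_len % cp_size == 0, "sequence length must divide evenly across CP ranks"
    chunk = seq_len // cp_size
    start, end = cp_rank * chunk, (cp_rank + 1) * chunk
    # Returned tables: [max_position_embeddings // cp_size, head_dim]
    return cos[start:end], sin[start:end]

# Example: with 4 CP ranks, rank 1 keeps positions [L/4, 2L/4) of the tables.
cos = torch.randn(4096, 64)
sin = torch.randn(4096, 64)
cos_local, sin_local = update_rope_for_context_parallel_sketch(cos, sin, cp_rank=1, cp_size=4)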