def forward()

in picotron/model.py [0:0]


    def forward(self, x, cos, sin, attention_mask=None, position_ids=None):
        """Grouped-query self-attention over ``x`` with rotary position embeddings.

        Environment variables choose the attention backend. Each is read once per call,
        so the rotary-layout branch and the backend branch always agree:

        * ``CONTEXT_PARALLEL == '1'`` (default ``'0'``): ring attention via
          ``context_parallel.ring_attention``.
        * otherwise ``FLASH_ATTEN == '1'`` (default ``'1'``): ``flash_attention``.
        * otherwise: ``F.scaled_dot_product_attention``.

        ``FLASH_ATTEN`` also chooses the rotary implementation:
        ``apply_rotary_emb`` runs on the ``[batch, seq, heads, head_dim]`` layout with
        the first ``head_dim // 2`` columns of ``cos``/``sin`` when it is ``'1'``.
        Otherwise ``apply_rotary_pos_emb`` runs on the ``[batch, heads, seq, head_dim]``
        layout.

        Args:
            x: Input tensor, unpacked as ``(batch_size, seq_length, hidden_dim)``.
            cos: Rotary cosine table. NOTE(review): assumed to be indexed as
                ``[position, dim]``; the flash path slices ``cos[:, :head_dim // 2]``.
            sin: Rotary sine table, same layout as ``cos``.
            attention_mask: Accepted for interface compatibility but currently unused.
                The only masking applied is the causal flag computed below.
            position_ids: Accepted for interface compatibility but currently unused.

        Returns:
            ``self.out_proj`` applied to the merged per-head attention outputs
            (presumably ``[batch_size, seq_length, hidden_dim]``).
        """
        batch_size, seq_length, _ = x.size()
        # Read each switch once so both branch points below see the same decision.
        use_context_parallel = os.getenv('CONTEXT_PARALLEL', '0') == '1'
        use_flash = os.getenv('FLASH_ATTEN', '1') == '1'

        q = self.q_proj(x)  # [batch_size, seq_length, num_heads*head_dim]
        k = self.k_proj(x)  # [batch_size, seq_length, num_key_values*head_dim]
        v = self.v_proj(x)  # [batch_size, seq_length, num_key_values*head_dim]

        if use_flash:
            # apply_rotary_emb (presumably flash_attn's) works on [batch, seq, heads, head_dim]
            # and takes half-width cos/sin tables; interleaved=False rotates the first half
            # against the second half rather than even/odd pairs.
            q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim)     # [batch_size, seq_length, num_heads, head_dim]
            k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim)  # [batch_size, seq_length, num_key_values, head_dim]
            half = self.head_dim // 2
            q = apply_rotary_emb(q, cos[:, :half], sin[:, :half], interleaved=False)
            k = apply_rotary_emb(k, cos[:, :half], sin[:, :half], interleaved=False)
            q = q.transpose(1, 2)  # [batch_size, num_heads, seq_length, head_dim]
            k = k.transpose(1, 2)  # [batch_size, num_key_values, seq_length, head_dim]
            v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_key_values, seq_length, head_dim]
        else:
            q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim).transpose(1, 2)     # [batch_size, num_heads, seq_length, head_dim]
            k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_key_values, seq_length, head_dim]
            v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_key_values, seq_length, head_dim]
            q = apply_rotary_pos_emb(q, cos, sin)
            k = apply_rotary_pos_emb(k, cos, sin)

        # Grouped-query attention: repeat each KV head so k/v have as many heads as q.
        n_rep = self.num_local_heads // self.num_local_kv_heads
        k = k.repeat_interleave(n_rep, dim=1)  # [batch_size, num_heads, seq_length, head_dim]
        v = v.repeat_interleave(n_rep, dim=1)  # [batch_size, num_heads, seq_length, head_dim]

        # Causal only when query and key lengths match. q and k are both built from
        # seq_length above, so this is always True here. The check presumably anticipates a
        # decoding path where q is shorter (typically length 1) than a cached k.
        causal = q.size(2) == k.size(2)

        # TODO: replace everything with flex attention
        if use_context_parallel:
            # Ring attention for context parallelism; softmax scale is 1/sqrt(head_dim).
            sm_scale = 1.0 / (q.size(-1) ** 0.5)
            out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2)  # [batch_size, seq_length, num_heads, head_dim]
        elif use_flash:
            # Flash attention (faster); returns the [batch, seq, heads, head_dim] layout.
            out = flash_attention(q, k, v, causal=causal)  # [batch_size, seq_length, num_heads, head_dim]
        else:
            # PyTorch scaled dot product attention.
            out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)  # [batch_size, num_heads, seq_length, head_dim]
            out = out.transpose(1, 2)  # [batch_size, seq_length, num_heads, head_dim]

        # Merge heads. The width is num_local_heads * head_dim (presumably the heads held
        # by this tensor-parallel rank), which need not equal hidden_dim.
        out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim)
        return self.out_proj(out)