in picotron/pipeline_parallel/pipeline_parallel.py [0:0]
def forward(self, input_ids, position_ids, hidden_states):
    """Run this pipeline stage's slice of the model and return its output.

    Args:
        input_ids: Raw inputs for the stage; used only when ``hidden_states``
            is None.
        position_ids: Passed unchanged, as the ``position_ids`` keyword, to
            every decoder layer owned by this stage.
        hidden_states: Activations received from an upstream stage, or None.
            Any non-None value (even an empty/falsy one) takes precedence
            over ``input_ids``.

    Returns:
        The result of ``self.final_proj`` applied to the normalized output of
        the last decoder layer. Sending it to the next stage is left to the
        caller; this method only returns it.
    """
    # Upstream activations win over raw inputs whenever they are provided.
    stage_input = input_ids if hidden_states is None else hidden_states

    # NOTE(review): the embedding is applied to upstream hidden states too;
    # presumably it is a pass-through module on non-first stages -- confirm
    # in the constructor.
    activations = self.embedding(stage_input)

    # Decoder layers run in the iteration order of ``self.decoder_layers``.
    for decoder_layer in self.decoder_layers.values():
        activations = decoder_layer(activations, position_ids=position_ids)

    normalized = self.final_norm(activations)
    return self.final_proj(normalized)