def forward()

in picotron/pipeline_parallel/pipeline_parallel.py [0:0]


    def forward(self, input_ids, position_ids, hidden_states):
        """Run this pipeline stage's portion of the model on one input.

        ``hidden_states`` is used as the starting tensor when it is not None;
        otherwise the stage starts from ``input_ids``. The selected value is
        passed through ``self.embedding``, then through every module in
        ``self.decoder_layers`` (in ``.values()`` iteration order, each called
        with ``position_ids`` as a keyword argument), then ``self.final_norm``
        and finally ``self.final_proj``.

        NOTE(review): all four steps run unconditionally on every stage. This
        presumably relies on stages that do not own the embedding or the
        output head being built with identity modules. Confirm this against
        the class constructor.

        Args:
            input_ids: Starting input, used only when ``hidden_states`` is None.
            position_ids: Forwarded unchanged to every decoder layer.
            hidden_states: Presumably the activation received from the
                previous pipeline stage, or None on the first stage.

        Returns:
            The output of ``self.final_proj``. This method does not send the
            result anywhere; forwarding it to the next stage is the caller's
            job.
        """
        # Explicit None check: a falsy-but-present activation must still win.
        if hidden_states is None:
            activations = input_ids
        else:
            activations = hidden_states
        activations = self.embedding(activations)
        for decoder_layer in self.decoder_layers.values():
            activations = decoder_layer(activations, position_ids=position_ids)
        # final_norm is evaluated before final_proj, as in a sequential chain.
        return self.final_proj(self.final_norm(activations))