in src/sal/models/skywork_o1_prm/prm_model.py [0:0]
def forward(self, hidden_states):
    """Run the head: dropout, dtype alignment, then the summary projection.

    Args:
        hidden_states: Tensor passed through ``self.dropout`` and then
            ``self.summary``.

    Returns:
        The result of ``self.summary`` applied to the dropped-out input,
        after that input has been cast to the dtype of
        ``self.summary.weight`` (if it was not already in that dtype).
    """
    dropped = self.dropout(hidden_states)
    target_dtype = self.summary.weight.dtype

    # The summary layer's weight dtype is the source of truth: activations
    # are converted to it whenever they disagree. This is meant as an fp32
    # upcast for numerical stability -- NOTE(review): that only holds if the
    # summary weights are actually kept in fp32; confirm where they are built.
    aligned = dropped if dropped.dtype == target_dtype else dropped.to(target_dtype)

    return self.summary(aligned)