in src/sal/models/skywork_o1_prm/io_utils.py [0:0]
def derive_step_rewards(
    rewards: "torch.Tensor",
    reward_flags: "torch.Tensor",
) -> "list[list[float]]":
    """Collect the rewards at flagged positions, one list per batch row.

    For every row ``i`` of the batch, the positions where
    ``reward_flags[i] == 1`` are selected in increasing position order. The
    matching entries of ``rewards[i]`` are returned as Python scalars.

    Args:
        rewards: Tensor indexed as ``rewards[i][pos]``. It is presumably
            shaped ``(batch, seq_len)`` with one score per position; TODO
            confirm against the caller.
        reward_flags: Tensor whose row ``i`` lines up position-for-position
            with ``rewards[i]``. Entries equal to 1 mark the positions to
            read; any other value is ignored.

    Returns:
        A list with ``rewards.shape[0]`` entries. Entry ``i`` holds the
        selected rewards of row ``i``, or an empty list if nothing in that
        row is flagged.
    """
    batch_step_rewards = []
    for i in range(rewards.shape[0]):
        # Gather with a boolean mask, then call ``tolist()`` once per row.
        # This does one device-to-host transfer per row instead of one
        # ``.item()`` sync per flagged step.
        step_mask = reward_flags[i] == 1
        batch_step_rewards.append(rewards[i][step_mask].tolist())
    return batch_step_rewards