def derive_step_rewards()

in src/sal/models/skywork_o1_prm/io_utils.py [0:0]


def derive_step_rewards(rewards, reward_flags):
    """Collect, per batch row, the reward values at flagged positions.

    For every row ``i`` the positions where ``reward_flags[i] == 1`` are
    located and the corresponding entries of ``rewards[i]`` are returned,
    in ascending position order, as plain Python numbers.

    Args:
        rewards: Tensor whose first dimension is the batch.
            NOTE(review): assumed to be 2-D ``(batch, seq_len)`` so that
            each selected entry is a scalar -- confirm against the caller.
        reward_flags: Tensor indexed the same way as ``rewards``; entries
            equal to ``1`` mark the positions to extract. Presumably the
            same shape as ``rewards`` -- verify against the caller.

    Returns:
        A list with one inner list per batch row holding the selected
        reward values. A row with no flagged position yields ``[]``.
    """
    batch_size = rewards.shape[0]
    batch_step_rewards = []
    for i in range(batch_size):
        # Flat positions of the flagged entries for this row. They are moved
        # to the rewards' device so the gather below also works when the two
        # tensors live on different devices.
        flagged_positions = (
            torch.nonzero(reward_flags[i] == 1).view(-1).to(rewards.device)
        )
        # One gather + one ``tolist()`` per row instead of an ``.item()``
        # call (and a potential device sync) per flagged element.
        batch_step_rewards.append(rewards[i][flagged_positions].tolist())
    return batch_step_rewards