in src/sal/models/skywork_o1_prm/io_utils.py [0:0]
def prepare_batch_input_for_model(input_ids, reward_flags, pad_token_id):
    """Right-pad a batch of token-id and reward-flag sequences into tensors.

    Args:
        input_ids: Collection of per-sample token-id sequences. It is
            iterated twice (once for the ids, once for the mask).
        reward_flags: Collection of per-sample flag sequences.
            NOTE(review): presumably aligned one-to-one with ``input_ids``;
            this function does not check that — confirm against the caller.
        pad_token_id: Value written into the padded positions of the ids.

    Returns:
        A ``(ids, attention_mask, reward_flags)`` tuple of batch-first
        ``LongTensor`` objects. The attention mask holds 1 for every original
        position and 0 for padding; reward flags are padded with 0.
    """

    def _pad_batch(tensors, fill_value):
        # Every output shares the batch-first layout; only the fill differs.
        return torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=fill_value
        )

    id_tensors = [torch.LongTensor(seq) for seq in input_ids]
    batch_ids = _pad_batch(id_tensors, pad_token_id)

    # A run of ones per sample; the zero padding marks the filler positions.
    mask_tensors = [torch.ones(len(seq), dtype=torch.long) for seq in input_ids]
    batch_mask = _pad_batch(mask_tensors, 0)

    flag_tensors = [torch.LongTensor(flags) for flags in reward_flags]
    batch_flags = _pad_batch(flag_tensors, 0)

    return batch_ids, batch_mask, batch_flags