# def prepare_batch_input_for_model()
#
# in src/sal/models/skywork_o1_prm/io_utils.py [0:0]


def prepare_batch_input_for_model(input_ids, reward_flags, pad_token_id):
    """Right-pad a batch of variable-length sequences into batch-first tensors.

    Args:
        input_ids: One sequence of token ids per example.
        reward_flags: One sequence of integer flags per example. Presumably
            aligned position-for-position with ``input_ids``; that is not
            checked here -- confirm against the caller.
        pad_token_id: Fill value written into the padded token positions.

    Returns:
        A tuple ``(padded_input_ids, padded_attention_mask,
        padded_reward_flags)`` of ``LongTensor`` objects with the batch
        dimension first. Each tensor is padded to the longest sequence of its
        own input list. The attention mask holds 1 at real token positions
        and 0 at padding; padded reward-flag positions are 0.
    """

    def _pad_batch(rows, fill):
        # Convert every row to a LongTensor, then extend the shorter rows
        # with ``fill`` so the batch becomes rectangular.
        return torch.nn.utils.rnn.pad_sequence(
            list(map(torch.LongTensor, rows)),
            batch_first=True,
            padding_value=fill,
        )

    token_batch = _pad_batch(input_ids, pad_token_id)
    # The mask is derived from sequence lengths only: one `1` per real token.
    mask_batch = _pad_batch(([1] * len(seq) for seq in input_ids), 0)
    flag_batch = _pad_batch(reward_flags, 0)
    return token_batch, mask_batch, flag_batch