in src/sal/models/skywork_o1_prm/io_utils.py [0:0]
def prepare_input(problem, response, tokenizer, step_token):
    """Tokenize a problem/response pair and mark where each step ends.

    The prompt is ``tokenizer.bos_token + problem + "\\n"``. The response is
    split on ``step_token``; every resulting piece is encoded on its own and
    followed by the id of the step token, and that trailing position is the
    only one flagged with ``1``.

    Args:
        problem: Problem statement text.
        response: Response text whose steps are separated by ``step_token``.
        tokenizer: Object exposing a ``bos_token`` string and an
            ``encode(text)`` method returning a list of token ids
            (presumably a Hugging Face tokenizer -- confirm against caller).
        step_token: Separator string delimiting steps in ``response``.

    Returns:
        A tuple ``(input_ids, steps, reward_flags)``:

        * ``input_ids`` -- prompt ids followed by the ids of every step.
        * ``steps`` -- each piece of the response with ``step_token``
          appended back onto it.
        * ``reward_flags`` -- same length as ``input_ids``; ``0`` everywhere
          except ``1`` at the last token of each step.

    Note:
        Because ``str.split`` keeps empty pieces, an empty ``response`` or a
        ``response`` that ends with ``step_token`` produces a final step that
        consists of the step token alone.
    """
    prompt_ids = tokenizer.encode(tokenizer.bos_token + problem + "\n")
    # Only the last id is kept, so any ids the tokenizer emits ahead of the
    # step token itself are discarded.
    step_token_id = tokenizer.encode(step_token)[-1]

    response_ids = []
    steps = []
    # Prompt positions never carry a reward flag.
    reward_flags = [0] * len(prompt_ids)

    for step in response.split(step_token):
        # Empty pieces are not sent through the tokenizer at all.
        step_ids = tokenizer.encode(step) if step != "" else []
        step_ids += [step_token_id]

        response_ids.extend(step_ids)
        # Flag only the closing step-token position of this step.
        reward_flags.extend([0] * (len(step_ids) - 1) + [1])
        steps.append(step + step_token)

    input_ids = prompt_ids + response_ids
    return input_ids, steps, reward_flags