def main()

in code/run_eval_prm_rlhflow.py [0:0]


def main():
    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",
    )
    os.makedirs('outputs/Llama3.1-8B-PRM-Mistral-Data', exist_ok=True)

    def single_process(d):
        steps = d['steps']
        messages = []
        for sdx, step in enumerate(steps):
            if sdx == 0:
                messages.append({'role': 'user', 'content': d['problem'] + '\n\n' + step})
            else:
                messages.append({'role': 'user', 'content': step})
            completion = client.chat.completions.create(
                model='Llama3.1-8B-PRM-Mistral-Data',
                messages=messages,
                n=1,
                temperature=0.,
                max_tokens=1,
            )
            judgment = completion.choices[0].message.content.strip().lower().startswith('+')
            if not judgment:
                return sdx
            messages.append({'role': 'assistant', 'content': '+'})
        return -1

    configs = ['gsm8k', 'math', 'olympiadbench', 'omnimath']
    for config in configs:
        input_data = load_dataset('/cpfs01/user/zhengchujie.zcj/hf_datasets/Qwen/ProcessBench', split=config)
        with Pool(32) as p:
            predictions = list(tqdm(p.imap(single_process, input_data), total=len(input_data),
                                    desc=f'Processing {config}', dynamic_ncols=True))
        
        res_data = []
        for idx, d in enumerate(input_data):
            new_d = d.copy()
            new_d['prediction'] = predictions[idx]
            new_d['match'] = predictions[idx] == d['label']
            res_data.append(new_d)
        
        data1 = [e for e in res_data if e['label'] != -1]
        data2 = [e for e in res_data if e['label'] == -1]
        with open(f'outputs/Llama3.1-8B-PRM-Mistral-Data/{config}_error.jsonl', 'w') as f:
            for e in data1:
                f.write(json.dumps(e) + '\n')
        with open(f'outputs/Llama3.1-8B-PRM-Mistral-Data/{config}_correct.jsonl', 'w') as f:
            for e in data2:
                f.write(json.dumps(e) + '\n')
        
        acc1 = np.mean([e['match'] for e in data1]) * 100
        acc2 = np.mean([e['match'] for e in data2]) * 100
        f1 = 2 * acc1 * acc2 / (acc1 + acc2)
        print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')