in code/run_eval.py [0:0]
import argparse
import json
import os
from collections import Counter

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# prepare_input_boxed, apply_chat_template, and extract_answer are project-local helpers
# assumed to be defined (or imported) elsewhere in this file.


def main():
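    # Command-line options: which ProcessBench splits to run, the model to evaluate as a critic,
    # and whether to aggregate several sampled critiques by majority voting.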
    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', type=str, nargs='+', default=None,
                        choices=['gsm8k', 'math', 'olympiadbench', 'omnimath'])
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--use_voting', action='store_true')
    parser.add_argument('--voting_n', type=int, default=8)
    args = parser.parse_args()
    args.model_name = os.path.basename(args.model_path)
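
    # Load the tokenizer and the critique prompt template used to wrap each problem and its step-by-step solution.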
    toker = AutoTokenizer.from_pretrained(args.model_path)
    TEMPLATE = open('./templates/critique_template.txt').read().strip()
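
    # Initialize the vLLM engine across all visible GPUs; prefix caching helps because every prompt shares the same template prefix.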
    llm = LLM(
        model=args.model_path, tokenizer=args.model_path,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=torch.cuda.device_count(),
        enable_prefix_caching=True, swap_space=16,
        max_num_seqs=20,
    )
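
    # Greedy decoding when taking a single critique per problem; otherwise sample `voting_n` critiques for majority voting.
    # QwQ models get a larger generation budget to accommodate their long reasoning traces.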
    if not args.use_voting:
        sampling_params = SamplingParams(temperature=0.,
                                         max_tokens=32768 if 'QwQ' in args.model_path else 8192, seed=42)
    else:
        if 'Qwen2.5-Math' in args.model_path:  # to ensure normal generation of Qwen2.5-Math-7B/72B-Instruct
            sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, n=args.voting_n,
                                             max_tokens=32768 if 'QwQ' in args.model_path else 8192, seed=42)
        else:
            sampling_params = SamplingParams(temperature=1, top_p=0.9, n=args.voting_n,
                                             max_tokens=32768 if 'QwQ' in args.model_path else 8192, seed=42)
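
    # Default to evaluating all four ProcessBench splits.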
    if args.configs is None:
        args.configs = ['gsm8k', 'math', 'olympiadbench', 'omnimath']
    for config in args.configs:
        if not args.use_voting:
            output_dir = os.path.join(args.output_dir, args.model_name)
        else:
            output_dir = os.path.join(args.output_dir, f'{args.model_name}_voting')
        os.makedirs(output_dir, exist_ok=True)
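
        # Load the ProcessBench split and build one critique prompt per example.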
        input_data = load_dataset('Qwen/ProcessBench', split=config)
        prompt_token_ids = [apply_chat_template(toker, prepare_input_boxed(TEMPLATE, e))
                            for e in input_data]
        generations = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
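
        # Parse each generated critique into a predicted step index; with voting, keep the most common extracted answer.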
        res_data = []
        for i in range(len(input_data)):
            d = input_data[i].copy()
            if not args.use_voting:
                generated_critique = generations[i].outputs[0].text
                pred = extract_answer(generated_critique)
                try:
                    pred = int(pred)
                except (ValueError, TypeError):
                    pred = None
            else:
                generated_critique = [ee.text for ee in generations[i].outputs]
                preds = [extract_answer(e) for e in generated_critique]
                preds = [e for e in preds if e is not None]
                if len(preds) == 0:
                    pred = None
                else:
                    pred = Counter(preds).most_common(1)[0][0]
                    try:
                        pred = int(pred)
                    except (ValueError, TypeError):
                        pred = None
            d['generated_critique'] = generated_critique
            d['prediction'] = pred
            d['match'] = (pred == d['label'])
            res_data.append(d)
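
        # A label of -1 marks a fully correct solution; any other label is the index of the earliest erroneous step.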
        error_data = [e for e in res_data if e['label'] != -1]
        correct_data = [e for e in res_data if e['label'] == -1]
        with open(os.path.join(output_dir, f'{config}_error.jsonl'), 'w') as f:
            for e in error_data:
                f.write(json.dumps(e) + '\n')
        with open(os.path.join(output_dir, f'{config}_correct.jsonl'), 'w') as f:
            for e in correct_data:
                f.write(json.dumps(e) + '\n')
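
        # Accuracy on the erroneous and fully correct subsets, combined via their harmonic mean (F1).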
        acc1 = np.mean([e['match'] for e in error_data]) * 100
        acc2 = np.mean([e['match'] for e in correct_data]) * 100
        f1 = 2 * acc1 * acc2 / (acc1 + acc2)
        print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')


if __name__ == '__main__':
    main()