evaluate_model_outputs.py:

import argparse
from typing import Any

import pandas as pd
import sympy
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig


def parse_args():
    parser = argparse.ArgumentParser(description='Extract and evaluate answers using sympy')
    parser.add_argument('--input_csv', type=str, required=True,
                        help='Path to input CSV file containing model outputs')
    parser.add_argument('--output_csv', type=str, required=True,
                        help='Path to output CSV file for extracted answers')
    parser.add_argument('--gold_is_latex', action='store_true', default=True,
                        help='Use basic latex normalization')
    return parser.parse_args()


def load_csv_data(csv_path: str) -> pd.DataFrame:
    """Load and validate CSV data."""
    try:
        df = pd.read_csv(csv_path)
        required_columns = ['answer', 'gold']
        if not all(col in df.columns for col in required_columns):
            raise ValueError(f"CSV must contain columns: {required_columns}")
        return df
    except Exception as e:
        raise Exception(f"Error loading CSV file: {str(e)}")


def serialize_sympy_object(obj: Any) -> str:
    """Convert sympy object to string representation."""
    if obj is None:
        return ""
    try:
        if isinstance(obj, (list, tuple)):
            return ", ".join(str(x) if x is not None else "" for x in obj)
        return str(obj)
    except Exception as e:
        return f"Error: {str(e)}"


def compare_answers(extracted: Any, gold: Any) -> bool:
    """Compare extracted answer with gold answer."""
    if extracted is None or gold is None:
        return False
    try:
        # Handle lists/tuples of expressions
        if isinstance(extracted, (list, tuple)) and isinstance(gold, (list, tuple)):
            if len(extracted) != len(gold):
                return False
            return all(sympy.simplify(a - b) == 0 for a, b in zip(extracted, gold))
        # Handle single expressions
        return sympy.simplify(extracted - gold) == 0
    except Exception:
        # If comparison fails (e.g. different types), return False
        return False
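
# A minimal illustration of the sympy comparison above (a sketch; note that
# compare_answers and serialize_sympy_object are standalone helpers and are not
# called by the math_verify-based flow below):
#
#   x = sympy.Symbol("x")
#   compare_answers((x + 1) ** 2, x**2 + 2 * x + 1)              # True: difference simplifies to 0
#   compare_answers(sympy.Rational(1, 2), sympy.Rational(1, 3))  # False: difference is 1/6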

def process_answers(df: pd.DataFrame, gold_is_latex: bool) -> pd.DataFrame:
    """Extract answers with math_verify and compare each one against the gold answer."""
    results = []
    correct_count = 0
    total_count = 0

    # Create the verification function
    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig() if gold_is_latex else ExprExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
        aggregation_function=max,
        precision=6,
    )

    for _, row in df.iterrows():
        extracted_answers = None
        gold_answers = None
        grade = 0
        try:
            # Use the verification function
            grade, extracted_answers = verify_func([row['gold']], [row['answer']])
            # When extraction succeeds, the first element is the parsed gold
            # answer and the second is the parsed prediction
            if extracted_answers is not None:
                gold_answers = extracted_answers[0]
                extracted_answers = extracted_answers[1]
            total_count += 1
            if grade == 1:
                correct_count += 1
            results.append({
                'original_answer': row['answer'],
                'gold_answer': row['gold'],
                'extracted_answer': extracted_answers,
                'extracted_gold': gold_answers,
                'is_correct': grade == 1,
            })
        except Exception as e:
            results.append({
                'original_answer': row['answer'],
                'gold_answer': row['gold'],
                'extracted_answer': extracted_answers,
                'extracted_gold': gold_answers,
                'is_correct': grade == 1,
                'error': str(e),
            })

    results_df = pd.DataFrame(results)

    # Calculate accuracy
    accuracy = correct_count / total_count if total_count > 0 else 0
    print("\nEvaluation Results:")
    print(f"Total examples: {total_count}")
    print(f"Correct answers: {correct_count}")
    print(f"Accuracy: {accuracy:.2%}")

    # Add summary stats to the dataframe
    results_df.attrs['accuracy'] = accuracy
    results_df.attrs['total_count'] = total_count
    results_df.attrs['correct_count'] = correct_count

    return results_df


def main():
    args = parse_args()

    # Load input CSV
    input_df = load_csv_data(args.input_csv)

    # Process answers and compare against the gold answers
    results_df = process_answers(input_df, args.gold_is_latex)

    # Save results to output CSV
    results_df.to_csv(args.output_csv, index=False)
    print(f"\nResults saved to {args.output_csv}")


if __name__ == "__main__":
    main()
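
# Example usage (a sketch; assumes a CSV with 'answer' and 'gold' columns and
# that the math_verify package is installed; file names are placeholders):
#
#   python evaluate_model_outputs.py --input_csv model_outputs.csv --output_csv graded.csv
#
# The verification step can also be exercised directly, mirroring the call in
# process_answers above:
#
#   verify = math_metric(
#       gold_extraction_target=(LatexExtractionConfig(),),
#       pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
#       aggregation_function=max,
#       precision=6,
#   )
#   grade, extracted = verify(["$\\frac{1}{2}$"], ["The answer is 1/2."])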