in src/math_verify/grader.py [0:0]
def _symbol_product_name(expr: Basic | MatrixBase) -> str | None:
    """Return the concatenated name of a product made only of Symbols and E.

    Each Symbol contributes its ``name`` and each ``E`` factor contributes
    ``"e"``, joined in ``expr.args`` order. Returns ``None`` when ``expr`` is
    not a ``Mul`` or when any factor is something other than a Symbol or E.
    """
    if not isinstance(expr, Mul):
        return None
    parts = []
    for arg in expr.args:
        if isinstance(arg, Symbol):
            parts.append(arg.name)
        elif arg == E:
            # E is treated as the letter "e", since it is always parsed as exp.
            parts.append("e")
        else:
            return None
    return "".join(parts)


def _normalize_symbol_name(name: str) -> str:
    """Lowercase multi-character symbol names; keep single letters as-is.

    Single-letter names stay case-sensitive, so ``a`` and ``A`` remain distinct
    symbols, while longer names such as ``Abc`` / ``abc`` compare equal.
    """
    return name.lower() if len(name) > 1 else name


def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions where at least one is expected to be a Symbol.

    Handles special cases:
    - One side is a Symbol named ``e``/``E`` and the other is the constant ``E``
      (a limitation of parsed expressions).
    - One side is a product of Symbols (and possibly ``E``) and the other is a
      single Symbol. Parsing turns ``$abc$`` into ``a*b*c``, so the factor names
      are concatenated and compared case-insensitively.
    - Both sides are Symbols. Names longer than one character are compared
      case-insensitively; single-letter names are compared case-sensitively.

    Args:
        gold: First sympy expression (expected).
        pred: Second sympy expression (predicted).

    Returns:
        True if the expressions are equal under any of the rules above,
        False otherwise.
    """
    # Handle E vs symbol case.
    if (isinstance(gold, Symbol) and gold.name.lower() == "e" and pred == E) or (
        isinstance(pred, Symbol) and pred.name.lower() == "e" and gold == E
    ):
        return True

    # Handle multiplication of symbols vs single symbol, because parsing turns
    # $abc$ into a*b*c. E is included because it is always parsed as exp.
    # NOTE(review): the concatenation follows Mul.args order, which sympy may
    # canonicalise differently from the order in the source text (e.g. "ba")
    # — confirm this is acceptable.
    if isinstance(gold, Symbol):
        concat_pred = _symbol_product_name(pred)
        if concat_pred is not None:
            return gold.name.lower() == concat_pred.lower()
    if isinstance(pred, Symbol):
        concat_gold = _symbol_product_name(gold)
        if concat_gold is not None:
            return pred.name.lower() == concat_gold.lower()

    # Simple symbol-to-symbol comparison.
    if isinstance(gold, Symbol) and isinstance(pred, Symbol):
        return _normalize_symbol_name(gold.name) == _normalize_symbol_name(pred.name)

    return False