in src/lighteval/metrics/utils/math_comparison.py [0:0]
def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
    """Decide whether two sympy objects match when at least one side is a Symbol.

    Besides plain structural equality (``gold == pred``), two lenient cases are
    accepted:

    * A Symbol named ``"e"`` (any case) matches sympy's constant ``E`` on the
      other side, because parsing the same text can yield either object.
    * A single Symbol matches a ``Mul`` whose factors are all Symbols and/or
      ``E``, when the Symbol's name equals the factor names concatenated in
      ``Mul.args`` order (``E`` contributes ``"e"``), compared
      case-insensitively. This covers a multi-letter name that was parsed as
      a product of single-letter symbols.

    Args:
        gold: Reference expression.
        pred: Predicted expression.

    Returns:
        True if the two expressions are considered equal under the rules above.
    """

    def _is_euler_symbol(candidate, other) -> bool:
        # True when `candidate` is a Symbol spelled "e"/"E" and `other` is the constant E.
        return isinstance(candidate, Symbol) and candidate.name.lower() == "e" and other == E

    def _joined_factor_names(product) -> str | None:
        # Join the factor names of a Mul made purely of Symbols and E; None for anything else.
        # NOTE(review): sympy stores Mul args in a canonical order, which may differ
        # from the order the factors were originally written in — confirm if relevant.
        if not isinstance(product, Mul):
            return None
        pieces = []
        for factor in product.args:
            if isinstance(factor, Symbol):
                pieces.append(factor.name)
            elif factor == E:
                pieces.append("e")
            else:
                return None
        return "".join(pieces)

    # Symbol "e" on one side vs. Euler's constant on the other.
    if _is_euler_symbol(gold, pred) or _is_euler_symbol(pred, gold):
        return True

    # Single gold Symbol vs. a product of symbols in the prediction.
    if isinstance(gold, Symbol):
        joined = _joined_factor_names(pred)
        if joined is not None:
            return gold.name.lower() == joined.lower()

    # Single predicted Symbol vs. a product of symbols in the gold answer.
    if isinstance(pred, Symbol):
        joined = _joined_factor_names(gold)
        if joined is not None:
            return pred.name.lower() == joined.lower()

    # No special case applied: fall back to sympy structural equality.
    return gold == pred