in src/lighteval/metrics/utils/math_comparison.py [0:0]
def sympy_expr_eq(gold: Basic | MatrixBase, pred: Basic | MatrixBase, precision: int, strict: bool = True) -> bool:  # noqa: C901
    """Compare two sympy expressions for equality using multiple methods.

    The inputs are first normalised (optional variable renaming, reduction of
    equations to their right-hand side, conversion of a gold relation to a
    set) and then compared with progressively more expensive checks: string
    equality, then a type-specific comparison (relations, sets/tuples,
    symbols, or numeric followed by symbolic equality).

    Args:
        gold: First sympy expression (expected)
        pred: Second sympy expression (predicted)
        precision: Number of decimal places to compare
        strict: If true, variables do matter otherwise they don't

    Returns:
        True if expressions are equal by any comparison method, False otherwise

    Raises:
        TimeoutError: Re-raised unchanged when it occurs during one of the
            best-effort normalisation steps; every other exception raised
            there is deliberately ignored.
    """
    # This ensures that f(x) == f(y) is true
    if not strict:
        try:
            # free_symbols are sets, so pairing them directly depends on hash
            # iteration order. Sort by name to get a reproducible pairing
            # (and the identity mapping when both sides use the same symbols).
            gold_variables = sorted(gold.free_symbols, key=str)
            pred_variables = sorted(pred.free_symbols, key=str)
            if len(gold_variables) == len(pred_variables):
                # simultaneous=True prevents chained replacement: with a
                # sequential x->y, y->x mapping, x + 2*y would collapse to 3*x.
                pred = pred.subs(list(zip(pred_variables, gold_variables)), simultaneous=True)
        except TimeoutError:
            # Must not be swallowed by the broad handler below.
            raise
        except Exception:
            # Best effort: if renaming fails, compare the expressions as given.
            pass

    # If the target is relational, but the reference is not, it's possible it's a case of a=x+1+z, so we just take x+1+z
    # We only do this if the lhs of the first equation is fully symbolic, to prevent simplifying x+y+2z = 1
    if is_assignment_relation(gold) and not is_equation(pred):
        gold = take_last_relation(gold).rhs

    # Here we respect the gold and simplify accordingly, thus any of
    # k=x+1+z or 1+1+1=3 will be simplified to rhs
    if is_equation(pred) and not is_equation(gold):
        pred = take_last_relation(pred).rhs

    if is_relation(gold) and isinstance(pred, Set):
        # Convert the gold relation to a set so that a relation such as
        # 1 < x < 2 can be compared against a prediction given in set form.
        # We also unwrap the functions because otherwise it creates some conditional set based on the function name
        try:
            gold = unwrap_fcs(gold).as_set()
        except TimeoutError:
            # Must not be swallowed by the broad handler below.
            raise
        except Exception:
            # Best effort: keep gold as a relation if it cannot be converted.
            pass

    # Start with simple str and expr comparison as it's the fastest
    # str comparison is better than simple eq, because it will also handle misarrangements
    if sympy_str_eq(gold, pred):
        return True

    # Support for equations
    if is_relation(gold) and is_relation(pred):
        return sympy_compare_relational(gold, pred, precision)

    elif isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)):
        return sympy_compare_sets(gold, pred, precision)

    # Handles $\text{answer}$ == $answer$, one is symbol, is multiplication of symbols (a*n*s*w*e*r)
    elif isinstance(gold, Symbol) or isinstance(pred, Symbol):
        return sympy_compare_symbols(gold, pred)

    elif isinstance(gold, (Basic, MatrixBase)) and isinstance(pred, (Basic, MatrixBase)):
        # Mostly so that 0.333333 = 1/3
        if sympy_numeric_eq(gold, pred, precision):
            return True
        # Then try symbolic equality
        if sympy_symbolic_eq(gold, pred):
            return True

    return False