def sympy_expr_eq()

in src/lighteval/metrics/utils/math_comparison.py [0:0]


def sympy_expr_eq(gold: Basic | MatrixBase, pred: Basic | MatrixBase, precision: int, strict: bool = True) -> bool:  # noqa: C901
    """Compare two sympy expressions for equality using multiple methods.

    Comparisons are tried roughly from cheapest to most expensive: string
    equality first, then a relational / set / symbol specific comparison, and
    finally numeric followed by symbolic equality.

    Args:
        gold: First sympy expression (expected)
        pred: Second sympy expression (predicted)
        precision: Number of decimal places to compare
        strict: If True, variable names matter. If False and both expressions
            have the same number of free symbols, the free symbols of ``pred``
            are renamed (in name-sorted order) to those of ``gold`` before
            comparing, so that e.g. f(x) == f(y).

    Returns:
        True if expressions are equal by any comparison method, False otherwise
    """
    # This ensures that f(x) == f(y) is true
    if not strict:
        try:
            gold_variables = gold.free_symbols
            pred_variables = pred.free_symbols
            if len(gold_variables) == len(pred_variables):
                # free_symbols is a set, so zipping it directly would pair symbols in an
                # arbitrary, hash-dependent order. Sorting both sides by name makes the
                # renaming deterministic (and maps identical symbol sets onto themselves).
                mapping = list(zip(sorted(pred_variables, key=str), sorted(gold_variables, key=str)))
                # Substitute simultaneously: sequential substitution would chain renames
                # when the symbol sets overlap (x->y then y->z turns x + 2*y into 3*z
                # instead of y + 2*z).
                pred = pred.subs(mapping, simultaneous=True)
        except TimeoutError:
            raise
        except Exception:  # noqa: E722
            # Best effort: if renaming fails, compare the expressions unchanged.
            pass

    # If the gold is an assignment relation (e.g. a=x+1+z) but the prediction is not an equation,
    # compare against the right-hand side only (x+1+z).
    # This is restricted to assignments (fully symbolic lhs) so that e.g. x+y+2z = 1 is not reduced to 1.
    if is_assignment_relation(gold) and not is_equation(pred):
        gold = take_last_relation(gold).rhs

    # Here we respect the gold and simplify accordingly, thus any of
    # k=x+1+z or 1+1+1=3 will be simplified to rhs
    if is_equation(pred) and not is_equation(gold):
        pred = take_last_relation(pred).rhs

    if is_relation(gold) and isinstance(pred, Set):
        # This is to ensure that 1 < x < 2 equals (-oo, 1) U (2, oo)
        # We also unwrap the functions because otherwise it creates some conditional set based on the function name
        try:
            gold = unwrap_fcs(gold).as_set()
        except TimeoutError:
            raise
        except Exception:  # noqa: E722
            # Best effort: keep the relational gold if it cannot be turned into a set.
            pass

    # Start with simple str and expr comparison as it's the fastest
    # str comparison is better than simple eq, because it will also handle misarrangements
    if sympy_str_eq(gold, pred):
        return True

    # Support for equations
    if is_relation(gold) and is_relation(pred):
        return sympy_compare_relational(gold, pred, precision)

    elif isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)):
        return sympy_compare_sets(gold, pred, precision)

    # Handles $\text{answer}$ == $answer$, one is symbol, is multiplication of symbols (a*n*s*w*e*r)
    elif isinstance(gold, Symbol) or isinstance(pred, Symbol):
        return sympy_compare_symbols(gold, pred)

    elif isinstance(gold, (Basic, MatrixBase)) and isinstance(pred, (Basic, MatrixBase)):
        # Mostly so that 0.333333 = 1/3
        if sympy_numeric_eq(gold, pred, precision):
            return True
        # Then try symbolic equality
        if sympy_symbolic_eq(gold, pred):
            return True

    return False