def sympy_compare_symbols()

in src/lighteval/metrics/utils/math_comparison.py [0:0]


def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions where at least one is a Symbol.

    Handles special cases caused by limitations of the parsed expressions:

    - One side is a Symbol named ``e``/``E`` and the other is Euler's number ``E``.
    - One side is a single Symbol and the other is a product of Symbols (and
      possibly ``E``), e.g. a concatenated name such as ``ab`` that was parsed as
      ``a*b``. Because sympy keeps the factors of a ``Mul`` in canonical (sorted)
      order, the written order is not recoverable; the letters are therefore
      compared case-insensitively as a multiset rather than as an ordered string.

    Any other combination falls back to structural equality (``gold == pred``).

    Args:
        gold: The reference expression.
        pred: The predicted expression.

    Returns:
        True if the expressions are considered equal, False otherwise.
    """
    # Handle E vs symbol case
    if (isinstance(gold, Symbol) and gold.name.lower() == "e" and pred == E) or (
        isinstance(pred, Symbol) and pred.name.lower() == "e" and gold == E
    ):
        return True

    # Handle multiplication of symbols vs single symbol, in either direction
    if isinstance(gold, Symbol) and _is_symbol_product(pred):
        return _symbol_matches_product(gold, pred)
    if isinstance(pred, Symbol) and _is_symbol_product(gold):
        return _symbol_matches_product(pred, gold)

    return gold == pred


def _is_symbol_product(expr: Basic | MatrixBase) -> bool:
    """Return True if ``expr`` is a Mul whose factors are all Symbols or ``E``."""
    return isinstance(expr, Mul) and all(arg == E or isinstance(arg, Symbol) for arg in expr.args)


def _symbol_matches_product(symbol: Symbol, product: Mul) -> bool:
    """Return True if ``symbol``'s name uses the same letters as ``product``'s factors.

    ``E`` factors contribute the letter ``e``; the comparison is case-insensitive.
    """
    factor_names = "".join(arg.name if isinstance(arg, Symbol) else "e" for arg in product.args)
    # Mul.args follow sympy's canonical ordering, not the order the factors were
    # written in, so compare letter multisets instead of ordered strings.
    return sorted(symbol.name.lower()) == sorted(factor_names.lower())