def sympy_compare_symbols()

in src/math_verify/grader.py [0:0]


def _concat_symbol_product(expr: Basic | MatrixBase) -> str | None:
    """Return the names of a product of plain symbols joined into one string.

    Each ``Symbol`` factor contributes its ``name``, and each factor equal to
    ``E`` contributes the letter ``"e"``. ``E`` is handled because parsing
    turns a written ``e`` into Euler's number rather than a symbol.

    Args:
        expr: Candidate expression.

    Returns:
        The concatenated factor names, in the order of ``expr.args``, when
        ``expr`` is a ``Mul`` whose factors are all symbols or ``E``.
        ``None`` otherwise.
    """
    if not isinstance(expr, Mul):
        return None
    # NOTE(review): sympy keeps Mul factors in a canonical order, which may
    # differ from the order the letters were written in — confirm whether
    # callers rely on that.
    pieces = []
    for factor in expr.args:
        if isinstance(factor, Symbol):
            pieces.append(factor.name)
        elif factor == E:
            pieces.append("e")
        else:
            # Any non-symbol factor (number, power, function, ...) disqualifies
            # the product from being treated as a concatenated name.
            return None
    return "".join(pieces)


def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions where at least one is a Symbol.

    Handles special cases:
    - One is a Symbol named ``e``/``E`` and the other is ``E`` (a limitation of
      parsed expressions, where a written ``e`` becomes Euler's number).
    - One is a single Symbol and the other is a product of symbols (and/or
      ``E``). This arises because parsing turns ``$abc$`` into ``a*b*c``. The
      factor names are concatenated and compared case-insensitively.
    - Both are Symbols. Names longer than one character are compared
      case-insensitively; single-character names must match exactly.

    Args:
        gold: First sympy expression (expected).
        pred: Second sympy expression (predicted).

    Returns:
        True if the expressions are considered equal by one of the rules
        above, False otherwise.
    """
    # A symbol named e/E on one side versus Euler's number on the other.
    if (isinstance(gold, Symbol) and gold.name.lower() == "e" and pred == E) or (
        isinstance(pred, Symbol) and pred.name.lower() == "e" and gold == E
    ):
        return True

    # A single symbol versus a product of symbols: compare the concatenated name.
    if isinstance(gold, Symbol):
        concat_pred = _concat_symbol_product(pred)
        if concat_pred is not None:
            return gold.name.lower() == concat_pred.lower()

    if isinstance(pred, Symbol):
        concat_gold = _concat_symbol_product(gold)
        if concat_gold is not None:
            return pred.name.lower() == concat_gold.lower()

    # Plain symbol versus symbol.
    if isinstance(gold, Symbol) and isinstance(pred, Symbol):
        g_name = gold.name
        p_name = pred.name
        # Single-letter names stay case-sensitive (e.g. "A" and "a" can denote
        # different quantities); longer names are compared case-insensitively.
        if len(p_name) > 1:
            p_name = p_name.lower()
        if len(g_name) > 1:
            g_name = g_name.lower()
        return g_name == p_name

    return False