fn expr_to_sql_ok()

in datafusion/sql/src/unparser/expr.rs [1790:2237]


    /// Table-driven check that `expr_to_sql` renders each expression below as the
    /// exact SQL text paired with it.
    ///
    /// Every case is an `(Expr, &str)` tuple. The loop at the end unparses the
    /// expression, stringifies the resulting AST through its `Display` impl and
    /// compares it with the expected string. Errors raised while building a case or
    /// by `expr_to_sql` propagate through `?` and fail the test.
    fn expr_to_sql_ok() -> Result<()> {
        // Plan used as the subquery of the EXISTS / NOT EXISTS cases further down.
        let dummy_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        #[expect(deprecated)]
        let dummy_logical_plan = table_scan(Some("t"), &dummy_schema, None)?
            .project(vec![Expr::Wildcard {
                qualifier: None,
                options: Box::new(WildcardOptions::default()),
            }])?
            .filter(col("a").eq(lit(1)))?
            .build()?;

        let tests: Vec<(Expr, &str)> = vec![
            // Column references and binary comparisons.
            ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
            (
                Expr::Column(Column {
                    relation: Some(TableReference::partial("a", "b")),
                    name: "c".to_string(),
                    spans: Spans::new(),
                })
                .gt(lit(4)),
                r#"(b.c > 4)"#,
            ),
            // CASE expressions, with and without an operand.
            (
                case(col("a"))
                    .when(lit(1), lit(true))
                    .when(lit(0), lit(false))
                    .otherwise(lit(ScalarValue::Null))?,
                r#"CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#,
            ),
            (
                when(col("a").is_null(), lit(true)).otherwise(lit(false))?,
                r#"CASE WHEN a IS NULL THEN true ELSE false END"#,
            ),
            (
                when(col("a").is_not_null(), lit(true)).otherwise(lit(false))?,
                r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#,
            ),
            // CAST to date/time and unsigned integer types.
            (
                Expr::Cast(Cast {
                    expr: Box::new(col("a")),
                    data_type: DataType::Date64,
                }),
                r#"CAST(a AS DATETIME)"#,
            ),
            (
                Expr::Cast(Cast {
                    expr: Box::new(col("a")),
                    data_type: DataType::Timestamp(
                        TimeUnit::Nanosecond,
                        Some("+08:00".into()),
                    ),
                }),
                r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#,
            ),
            (
                Expr::Cast(Cast {
                    expr: Box::new(col("a")),
                    data_type: DataType::Timestamp(TimeUnit::Millisecond, None),
                }),
                r#"CAST(a AS TIMESTAMP)"#,
            ),
            (
                Expr::Cast(Cast {
                    expr: Box::new(col("a")),
                    data_type: DataType::UInt32,
                }),
                r#"CAST(a AS INTEGER UNSIGNED)"#,
            ),
            // IN / NOT IN lists.
            (
                col("a").in_list(vec![lit(1), lit(2), lit(3)], false),
                r#"a IN (1, 2, 3)"#,
            ),
            (
                col("a").in_list(vec![lit(1), lit(2), lit(3)], true),
                r#"a NOT IN (1, 2, 3)"#,
            ),
            // Scalar UDF calls, bare and wrapped in IS [NOT] NULL.
            (
                ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]),
                r#"dummy_udf(a, b)"#,
            ),
            (
                ScalarUDF::new_from_impl(DummyUDF::new())
                    .call(vec![col("a"), col("b")])
                    .is_null(),
                r#"dummy_udf(a, b) IS NULL"#,
            ),
            (
                ScalarUDF::new_from_impl(DummyUDF::new())
                    .call(vec![col("a"), col("b")])
                    .is_not_null(),
                r#"dummy_udf(a, b) IS NOT NULL"#,
            ),
            // Pattern matching: LIKE, ILIKE and SIMILAR TO inputs.
            (
                Expr::Like(Like {
                    negated: true,
                    expr: Box::new(col("a")),
                    pattern: Box::new(lit("foo")),
                    escape_char: Some('o'),
                    case_insensitive: false,
                }),
                r#"a NOT LIKE 'foo' ESCAPE 'o'"#,
            ),
            (
                Expr::Like(Like {
                    negated: true,
                    expr: Box::new(col("a")),
                    pattern: Box::new(lit("foo")),
                    escape_char: Some('o'),
                    case_insensitive: true,
                }),
                r#"a NOT ILIKE 'foo' ESCAPE 'o'"#,
            ),
            // NOTE(review): this `SimilarTo` input (with `case_insensitive: true`) is
            // expected to render as plain `LIKE`; the table pins that output, but
            // confirm it is the intended unparsing.
            (
                Expr::SimilarTo(Like {
                    negated: false,
                    expr: Box::new(col("a")),
                    pattern: Box::new(lit("foo")),
                    escape_char: Some('o'),
                    case_insensitive: true,
                }),
                r#"a LIKE 'foo' ESCAPE 'o'"#,
            ),
            // Date64 literals, expected as casts of a date-time string to DATETIME.
            (
                Expr::Literal(ScalarValue::Date64(Some(0))),
                r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#,
            ),
            (
                Expr::Literal(ScalarValue::Date64(Some(10000))),
                r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#,
            ),
            (
                Expr::Literal(ScalarValue::Date64(Some(-10000))),
                r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#,
            ),
            // Date32 literals, expected as casts of a date string to DATE.
            (
                Expr::Literal(ScalarValue::Date32(Some(0))),
                r#"CAST('1970-01-01' AS DATE)"#,
            ),
            (
                Expr::Literal(ScalarValue::Date32(Some(10))),
                r#"CAST('1970-01-11' AS DATE)"#,
            ),
            (
                Expr::Literal(ScalarValue::Date32(Some(-1))),
                r#"CAST('1969-12-31' AS DATE)"#,
            ),
            // Timestamp literals at each time unit, without and with a "+08:00" zone.
            (
                Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)),
                r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampSecond(
                    Some(10001),
                    Some("+08:00".into()),
                )),
                r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)),
                r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampMillisecond(
                    Some(10001),
                    Some("+08:00".into()),
                )),
                r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)),
                r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampMicrosecond(
                    Some(10001),
                    Some("+08:00".into()),
                )),
                r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)),
                r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#,
            ),
            (
                Expr::Literal(ScalarValue::TimestampNanosecond(
                    Some(10001),
                    Some("+08:00".into()),
                )),
                r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#,
            ),
            // Time32 / Time64 literals at each time unit.
            (
                Expr::Literal(ScalarValue::Time32Second(Some(10001))),
                r#"CAST('02:46:41' AS TIME)"#,
            ),
            (
                Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))),
                r#"CAST('00:00:10.001' AS TIME)"#,
            ),
            (
                Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))),
                r#"CAST('00:00:00.010001' AS TIME)"#,
            ),
            (
                Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))),
                r#"CAST('00:00:00.000010001' AS TIME)"#,
            ),
            // Aggregates, including DISTINCT and FILTER modifiers on count(*).
            (sum(col("a")), r#"sum(a)"#),
            (
                #[expect(deprecated)]
                count_udaf()
                    .call(vec![Expr::Wildcard {
                        qualifier: None,
                        options: Box::new(WildcardOptions::default()),
                    }])
                    .distinct()
                    .build()?,
                "count(DISTINCT *)",
            ),
            (
                #[expect(deprecated)]
                count_udaf()
                    .call(vec![Expr::Wildcard {
                        qualifier: None,
                        options: Box::new(WildcardOptions::default()),
                    }])
                    .filter(lit(true))
                    .build()?,
                "count(*) FILTER (WHERE true)",
            ),
            // Window functions: a window UDF and an aggregate used as a window.
            (
                Expr::WindowFunction(WindowFunction {
                    fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()),
                    params: WindowFunctionParams {
                        args: vec![col("col")],
                        partition_by: vec![],
                        order_by: vec![],
                        window_frame: WindowFrame::new(None),
                        null_treatment: None,
                    },
                }),
                r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#,
            ),
            (
                #[expect(deprecated)]
                Expr::WindowFunction(WindowFunction {
                    fun: WindowFunctionDefinition::AggregateUDF(count_udaf()),
                    params: WindowFunctionParams {
                        args: vec![Expr::Wildcard {
                            qualifier: None,
                            options: Box::new(WildcardOptions::default()),
                        }],
                        partition_by: vec![],
                        order_by: vec![Sort::new(col("a"), false, true)],
                        window_frame: WindowFrame::new_bounds(
                            datafusion_expr::WindowFrameUnits::Range,
                            datafusion_expr::WindowFrameBound::Preceding(
                                ScalarValue::UInt32(Some(6)),
                            ),
                            datafusion_expr::WindowFrameBound::Following(
                                ScalarValue::UInt32(Some(2)),
                            ),
                        ),
                        null_treatment: None,
                    },
                }),
                r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#,
            ),
            // Null and boolean truth-value predicates.
            (col("a").is_not_null(), r#"a IS NOT NULL"#),
            (col("a").is_null(), r#"a IS NULL"#),
            (
                (col("a") + col("b")).gt(lit(4)).is_true(),
                r#"((a + b) > 4) IS TRUE"#,
            ),
            (
                (col("a") + col("b")).gt(lit(4)).is_not_true(),
                r#"((a + b) > 4) IS NOT TRUE"#,
            ),
            (
                (col("a") + col("b")).gt(lit(4)).is_false(),
                r#"((a + b) > 4) IS FALSE"#,
            ),
            (
                (col("a") + col("b")).gt(lit(4)).is_not_false(),
                r#"((a + b) > 4) IS NOT FALSE"#,
            ),
            (
                (col("a") + col("b")).gt(lit(4)).is_unknown(),
                r#"((a + b) > 4) IS UNKNOWN"#,
            ),
            (
                (col("a") + col("b")).gt(lit(4)).is_not_unknown(),
                r#"((a + b) > 4) IS NOT UNKNOWN"#,
            ),
            // NOT, BETWEEN and unary minus.
            (not(col("a")), r#"NOT a"#),
            (
                Expr::between(col("a"), lit(1), lit(7)),
                r#"(a BETWEEN 1 AND 7)"#,
            ),
            (Expr::Negative(Box::new(col("a"))), r#"-a"#),
            // EXISTS / NOT EXISTS over the plan built at the top of the test.
            (
                exists(Arc::new(dummy_logical_plan.clone())),
                r#"EXISTS (SELECT * FROM t WHERE (t.a = 1))"#,
            ),
            (
                not_exists(Arc::new(dummy_logical_plan)),
                r#"NOT EXISTS (SELECT * FROM t WHERE (t.a = 1))"#,
            ),
            // TRY_CAST.
            (
                try_cast(col("a"), DataType::Date64),
                r#"TRY_CAST(a AS DATETIME)"#,
            ),
            (
                try_cast(col("a"), DataType::UInt32),
                r#"TRY_CAST(a AS INTEGER UNSIGNED)"#,
            ),
            // Scalar variables, placeholders and outer-reference columns.
            (
                Expr::ScalarVariable(Int8, vec![String::from("@a")]),
                r#"@a"#,
            ),
            (
                Expr::ScalarVariable(
                    Int8,
                    vec![String::from("@root"), String::from("foo")],
                ),
                r#"@root.foo"#,
            ),
            (col("x").eq(placeholder("$1")), r#"(x = $1)"#),
            (
                out_ref_col(DataType::Int32, "t.a").gt(lit(1)),
                r#"(t.a > 1)"#,
            ),
            // Grouping constructs.
            (
                grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]),
                r#"GROUPING SETS ((a, b), (a))"#,
            ),
            (cube(vec![col("a"), col("b")]), r#"CUBE (a, b)"#),
            (rollup(vec![col("a"), col("b")]), r#"ROLLUP (a, b)"#),
            // Identifiers whose expected output is double-quoted.
            (col("table").eq(lit(1)), r#"("table" = 1)"#),
            (
                col("123_need_quoted").eq(lit(1)),
                r#"("123_need_quoted" = 1)"#,
            ),
            (col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#),
            (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#),
            // See test_interval_scalar_to_expr for interval literals
            // Decimal literals: value 100123 with scale 3 is expected as 100.123.
            (
                (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128(
                    Some(100123),
                    28,
                    3,
                ))),
                r#"((a + b) > 100.123)"#,
            ),
            (
                (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256(
                    Some(100123.into()),
                    28,
                    3,
                ))),
                r#"((a + b) > 100.123)"#,
            ),
            // A Decimal128 cast target with precision 10 and negative scale -2 is
            // expected to come out as DECIMAL(12,0).
            (
                Expr::Cast(Cast {
                    expr: Box::new(col("a")),
                    data_type: DataType::Decimal128(10, -2),
                }),
                r#"CAST(a AS DECIMAL(12,0))"#,
            ),
            // UNNEST of a column qualified by schema and table.
            (
                Expr::Unnest(Unnest {
                    expr: Box::new(Expr::Column(Column {
                        relation: Some(TableReference::partial("schema", "table")),
                        name: "array_col".to_string(),
                        spans: Spans::new(),
                    })),
                }),
                r#"UNNEST("table".array_col)"#,
            ),
            // Array, struct, field-access and map constructors.
            (make_array(vec![lit(1), lit(2), lit(3)]), "[1, 2, 3]"),
            (array_element(col("array_col"), lit(1)), "array_col[1]"),
            (
                array_element(make_array(vec![lit(1), lit(2), lit(3)]), lit(1)),
                "[1, 2, 3][1]",
            ),
            (
                named_struct(vec![lit("a"), lit("1"), lit("b"), lit(2)]),
                "{a: '1', b: 2}",
            ),
            (get_field(col("a.b"), "c"), "a.b.c"),
            (
                map(vec![lit("a"), lit("b")], vec![lit(1), lit(2)]),
                "MAP {'a': 1, 'b': 2}",
            ),
            // Dictionary and list scalar literals.
            (
                Expr::Literal(ScalarValue::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(ScalarValue::Utf8(Some("foo".into()))),
                )),
                "'foo'",
            ),
            (
                Expr::Literal(ScalarValue::List(Arc::new(
                    ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
                        Some(1),
                        Some(2),
                        Some(3),
                    ])]),
                ))),
                "[1, 2, 3]",
            ),
            (
                Expr::Literal(ScalarValue::LargeList(Arc::new(
                    LargeListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(
                        vec![Some(1), Some(2), Some(3)],
                    )]),
                ))),
                "[1, 2, 3]",
            ),
            // The `<@` and `@>` binary operators.
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("a")),
                    op: Operator::ArrowAt,
                    right: Box::new(col("b")),
                }),
                "(a <@ b)",
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("a")),
                    op: Operator::AtArrow,
                    right: Box::new(col("b")),
                }),
                "(a @> b)",
            ),
        ];

        for (expr, expected) in tests {
            let actual = expr_to_sql(&expr)?.to_string();

            // Name the offending expression so a failure in this large table is
            // easy to trace back to its case.
            assert_eq!(actual, expected, "unexpected SQL for expression: {expr:?}");
        }

        Ok(())
    }