fn try_into_logical_plan()

in datafusion/proto/src/logical_plan/mod.rs [299:998]

Deserializes this protobuf LogicalPlanNode into a DataFusion LogicalPlan: it matches on the encoded plan variant and rebuilds the corresponding plan node using the supplied SessionContext and LogicalExtensionCodec.

    fn try_into_logical_plan(
        &self,
        ctx: &SessionContext,
        extension_codec: &dyn LogicalExtensionCodec,
    ) -> Result<LogicalPlan> {
        let plan = self.logical_plan_type.as_ref().ok_or_else(|| {
            proto_error(format!(
                "logical_plan::from_proto() Unsupported logical plan '{self:?}'"
            ))
        })?;
        match plan {
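            // Values: `values_list` is a flattened, row-major list of expressions
            // that is split back into rows of `n_cols` expressions each.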
            LogicalPlanType::Values(values) => {
                let n_cols = values.n_cols as usize;
                let values: Vec<Vec<Expr>> = if values.values_list.is_empty() {
                    Ok(Vec::new())
                } else if values.values_list.len() % n_cols != 0 {
                    internal_err!(
                        "Invalid values list length, expect {} to be divisible by {}",
                        values.values_list.len(),
                        n_cols
                    )
                } else {
                    values
                        .values_list
                        .chunks_exact(n_cols)
                        .map(|r| from_proto::parse_exprs(r, ctx, extension_codec))
                        .collect::<Result<Vec<_>, _>>()
                        .map_err(|e| e.into())
                }?;

                LogicalPlanBuilder::values(values)?.build()
            }
            LogicalPlanType::Projection(projection) => {
                let input: LogicalPlan =
                    into_logical_plan!(projection.input, ctx, extension_codec)?;
                let expr: Vec<Expr> =
                    from_proto::parse_exprs(&projection.expr, ctx, extension_codec)?;

                let new_proj = project(input, expr)?;
                match projection.optional_alias.as_ref() {
                    Some(a) => match a {
                        protobuf::projection_node::OptionalAlias::Alias(alias) => {
                            Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
                                Arc::new(new_proj),
                                alias.clone(),
                            )?))
                        }
                    },
                    _ => Ok(new_proj),
                }
            }
            LogicalPlanType::Selection(selection) => {
                let input: LogicalPlan =
                    into_logical_plan!(selection.input, ctx, extension_codec)?;
                let expr: Expr = selection
                    .expr
                    .as_ref()
                    .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec))
                    .transpose()?
                    .ok_or_else(|| proto_error("expression required"))?;
                LogicalPlanBuilder::from(input).filter(expr)?.build()
            }
            LogicalPlanType::Window(window) => {
                let input: LogicalPlan =
                    into_logical_plan!(window.input, ctx, extension_codec)?;
                let window_expr =
                    from_proto::parse_exprs(&window.window_expr, ctx, extension_codec)?;
                LogicalPlanBuilder::from(input).window(window_expr)?.build()
            }
            LogicalPlanType::Aggregate(aggregate) => {
                let input: LogicalPlan =
                    into_logical_plan!(aggregate.input, ctx, extension_codec)?;
                let group_expr =
                    from_proto::parse_exprs(&aggregate.group_expr, ctx, extension_codec)?;
                let aggr_expr =
                    from_proto::parse_exprs(&aggregate.aggr_expr, ctx, extension_codec)?;
                LogicalPlanBuilder::from(input)
                    .aggregate(group_expr, aggr_expr)?
                    .build()
            }
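            // ListingScan: rebuild a `ListingTable` provider from the serialized
            // schema, file format options, partition columns, sort order, and paths,
            // then wrap it in a table scan with the deserialized filters.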
            LogicalPlanType::ListingScan(scan) => {
                let schema: Schema = convert_required!(scan.schema)?;

                let mut projection = None;
                if let Some(columns) = &scan.projection {
                    let column_indices = columns
                        .columns
                        .iter()
                        .map(|name| schema.index_of(name))
                        .collect::<Result<Vec<usize>, _>>()?;
                    projection = Some(column_indices);
                }

                let filters =
                    from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?;

                let mut all_sort_orders = vec![];
                for order in &scan.file_sort_order {
                    all_sort_orders.push(from_proto::parse_sorts(
                        &order.sort_expr_nodes,
                        ctx,
                        extension_codec,
                    )?)
                }

                let file_format: Arc<dyn FileFormat> =
                    match scan.file_format_type.as_ref().ok_or_else(|| {
                        proto_error(format!(
                            "logical_plan::from_proto() Unsupported file format '{self:?}'"
                        ))
                    })? {
                        #[cfg_attr(not(feature = "parquet"), allow(unused_variables))]
                        FileFormatType::Parquet(protobuf::ParquetFormat { options }) => {
                            #[cfg(feature = "parquet")]
                            {
                                let mut parquet = ParquetFormat::default();
                                if let Some(options) = options {
                                    parquet = parquet.with_options(options.try_into()?)
                                }
                                Arc::new(parquet)
                            }
                            #[cfg(not(feature = "parquet"))]
                            panic!("Unable to process parquet file since `parquet` feature is not enabled");
                        }
                        FileFormatType::Csv(protobuf::CsvFormat {
                            options
                        }) => {
                            let mut csv = CsvFormat::default();
                            if let Some(options) = options {
                                csv = csv.with_options(options.try_into()?)
                            }
                            Arc::new(csv)
                        },
                        FileFormatType::Json(protobuf::NdJsonFormat {
                            options
                        }) => {
                            let mut json = OtherNdJsonFormat::default();
                            if let Some(options) = options {
                                json = json.with_options(options.try_into()?)
                            }
                            Arc::new(json)
                        }
                        #[cfg_attr(not(feature = "avro"), allow(unused_variables))]
                        FileFormatType::Avro(..) => {
                            #[cfg(feature = "avro")] 
                            {
                                Arc::new(AvroFormat)
                            }
                            #[cfg(not(feature = "avro"))]
                            panic!("Unable to process avro file since `avro` feature is not enabled");
                        }
                    };

                let table_paths = &scan
                    .paths
                    .iter()
                    .map(ListingTableUrl::parse)
                    .collect::<Result<Vec<_>, _>>()?;

                let partition_columns = scan
                    .table_partition_cols
                    .iter()
                    .map(|col| {
                        let Some(arrow_type) = col.arrow_type.as_ref() else {
                            return Err(proto_error(
                                "Missing Arrow type in partition columns",
                            ));
                        };
                        let arrow_type = DataType::try_from(arrow_type).map_err(|e| {
                            proto_error(format!("Received an unknown ArrowType: {}", e))
                        })?;
                        Ok((col.name.clone(), arrow_type))
                    })
                    .collect::<Result<Vec<_>>>()?;

                let options = ListingOptions::new(file_format)
                    .with_file_extension(&scan.file_extension)
                    .with_table_partition_cols(partition_columns)
                    .with_collect_stat(scan.collect_stat)
                    .with_target_partitions(scan.target_partitions as usize)
                    .with_file_sort_order(all_sort_orders);

                let config =
                    ListingTableConfig::new_with_multi_paths(table_paths.clone())
                        .with_listing_options(options)
                        .with_schema(Arc::new(schema));

                let provider = ListingTable::try_new(config)?.with_cache(
                    ctx.state()
                        .runtime_env()
                        .cache_manager
                        .get_file_statistic_cache(),
                );

                let table_name =
                    from_table_reference(scan.table_name.as_ref(), "ListingTableScan")?;

                LogicalPlanBuilder::scan_with_filters(
                    table_name,
                    provider_as_source(Arc::new(provider)),
                    projection,
                    filters,
                )?
                .build()
            }
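            // CustomScan: the table provider is opaque to this crate and is
            // reconstructed by the user-supplied `LogicalExtensionCodec`.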
            LogicalPlanType::CustomScan(scan) => {
                let schema: Schema = convert_required!(scan.schema)?;
                let schema = Arc::new(schema);
                let mut projection = None;
                if let Some(columns) = &scan.projection {
                    let column_indices = columns
                        .columns
                        .iter()
                        .map(|name| schema.index_of(name))
                        .collect::<Result<Vec<usize>, _>>()?;
                    projection = Some(column_indices);
                }

                let filters =
                    from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?;

                let table_name =
                    from_table_reference(scan.table_name.as_ref(), "CustomScan")?;

                let provider = extension_codec.try_decode_table_provider(
                    &scan.custom_table_data,
                    &table_name,
                    schema,
                    ctx,
                )?;

                LogicalPlanBuilder::scan_with_filters(
                    table_name,
                    provider_as_source(provider),
                    projection,
                    filters,
                )?
                .build()
            }
            LogicalPlanType::Sort(sort) => {
                let input: LogicalPlan =
                    into_logical_plan!(sort.input, ctx, extension_codec)?;
                let sort_expr: Vec<SortExpr> =
                    from_proto::parse_sorts(&sort.expr, ctx, extension_codec)?;
                let fetch: Option<usize> = sort.fetch.try_into().ok();
                LogicalPlanBuilder::from(input)
                    .sort_with_limit(sort_expr, fetch)?
                    .build()
            }
            LogicalPlanType::Repartition(repartition) => {
                use datafusion::logical_expr::Partitioning;
                let input: LogicalPlan =
                    into_logical_plan!(repartition.input, ctx, extension_codec)?;
                use protobuf::repartition_node::PartitionMethod;
                let pb_partition_method = repartition.partition_method.as_ref().ok_or_else(|| {
                    internal_datafusion_err!(
                        "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'"
                    )
                })?;

                let partitioning_scheme = match pb_partition_method {
                    PartitionMethod::Hash(protobuf::HashRepartition {
                        hash_expr: pb_hash_expr,
                        partition_count,
                    }) => Partitioning::Hash(
                        from_proto::parse_exprs(pb_hash_expr, ctx, extension_codec)?,
                        *partition_count as usize,
                    ),
                    PartitionMethod::RoundRobin(partition_count) => {
                        Partitioning::RoundRobinBatch(*partition_count as usize)
                    }
                };

                LogicalPlanBuilder::from(input)
                    .repartition(partitioning_scheme)?
                    .build()
            }
            LogicalPlanType::EmptyRelation(empty_relation) => {
                LogicalPlanBuilder::empty(empty_relation.produce_one_row).build()
            }
            LogicalPlanType::CreateExternalTable(create_extern_table) => {
                let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| {
                    DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, CreateExternalTableNode was missing required field schema."
                    ))
                })?;

                let constraints = (create_extern_table.constraints.clone()).ok_or_else(|| {
                    DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, CreateExternalTableNode was missing required table constraints.",
                    ))
                })?;
                let definition = if !create_extern_table.definition.is_empty() {
                    Some(create_extern_table.definition.clone())
                } else {
                    None
                };

                let file_type = create_extern_table.file_type.as_str();
                if ctx.table_factory(file_type).is_none() {
                    internal_err!("No TableProviderFactory for file type: {file_type}")?
                }

                let mut order_exprs = vec![];
                for expr in &create_extern_table.order_exprs {
                    order_exprs.push(from_proto::parse_sorts(
                        &expr.sort_expr_nodes,
                        ctx,
                        extension_codec,
                    )?);
                }

                let mut column_defaults =
                    HashMap::with_capacity(create_extern_table.column_defaults.len());
                for (col_name, expr) in &create_extern_table.column_defaults {
                    let expr = from_proto::parse_expr(expr, ctx, extension_codec)?;
                    column_defaults.insert(col_name.clone(), expr);
                }

                Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(
                    CreateExternalTable {
                        schema: pb_schema.try_into()?,
                        name: from_table_reference(
                            create_extern_table.name.as_ref(),
                            "CreateExternalTable",
                        )?,
                        location: create_extern_table.location.clone(),
                        file_type: create_extern_table.file_type.clone(),
                        table_partition_cols: create_extern_table
                            .table_partition_cols
                            .clone(),
                        order_exprs,
                        if_not_exists: create_extern_table.if_not_exists,
                        temporary: create_extern_table.temporary,
                        definition,
                        unbounded: create_extern_table.unbounded,
                        options: create_extern_table.options.clone(),
                        constraints: constraints.into(),
                        column_defaults,
                    },
                )))
            }
            LogicalPlanType::CreateView(create_view) => {
                let plan = create_view
                    .input.clone().ok_or_else(|| DataFusionError::Internal(String::from(
                    "Protobuf deserialization error, CreateViewNode has invalid LogicalPlan input.",
                )))?
                    .try_into_logical_plan(ctx, extension_codec)?;
                let definition = if !create_view.definition.is_empty() {
                    Some(create_view.definition.clone())
                } else {
                    None
                };

                Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
                    name: from_table_reference(create_view.name.as_ref(), "CreateView")?,
                    temporary: create_view.temporary,
                    input: Arc::new(plan),
                    or_replace: create_view.or_replace,
                    definition,
                })))
            }
            LogicalPlanType::CreateCatalogSchema(create_catalog_schema) => {
                let pb_schema = (create_catalog_schema.schema.clone()).ok_or_else(|| {
                    DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, CreateCatalogSchemaNode was missing required field schema.",
                    ))
                })?;

                Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(
                    CreateCatalogSchema {
                        schema_name: create_catalog_schema.schema_name.clone(),
                        if_not_exists: create_catalog_schema.if_not_exists,
                        schema: pb_schema.try_into()?,
                    },
                )))
            }
            LogicalPlanType::CreateCatalog(create_catalog) => {
                let pb_schema = (create_catalog.schema.clone()).ok_or_else(|| {
                    DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, CreateCatalogNode was missing required field schema.",
                    ))
                })?;

                Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalog(
                    CreateCatalog {
                        catalog_name: create_catalog.catalog_name.clone(),
                        if_not_exists: create_catalog.if_not_exists,
                        schema: pb_schema.try_into()?,
                    },
                )))
            }
            LogicalPlanType::Analyze(analyze) => {
                let input: LogicalPlan =
                    into_logical_plan!(analyze.input, ctx, extension_codec)?;
                LogicalPlanBuilder::from(input)
                    .explain(analyze.verbose, true)?
                    .build()
            }
            LogicalPlanType::Explain(explain) => {
                let input: LogicalPlan =
                    into_logical_plan!(explain.input, ctx, extension_codec)?;
                LogicalPlanBuilder::from(input)
                    .explain(explain.verbose, false)?
                    .build()
            }
            LogicalPlanType::SubqueryAlias(aliased_relation) => {
                let input: LogicalPlan =
                    into_logical_plan!(aliased_relation.input, ctx, extension_codec)?;
                let alias = from_table_reference(
                    aliased_relation.alias.as_ref(),
                    "SubqueryAlias",
                )?;
                LogicalPlanBuilder::from(input).alias(alias)?.build()
            }
            LogicalPlanType::Limit(limit) => {
                let input: LogicalPlan =
                    into_logical_plan!(limit.input, ctx, extension_codec)?;
                let skip = limit.skip.max(0) as usize;

                let fetch = if limit.fetch < 0 {
                    None
                } else {
                    Some(limit.fetch as usize)
                };

                LogicalPlanBuilder::from(input).limit(skip, fetch)?.build()
            }
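            // Join: equijoin keys are deserialized as expressions; `On` constraints
            // keep them as expression keys, while `Using` constraints require the
            // keys to be plain column references.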
            LogicalPlanType::Join(join) => {
                let left_keys: Vec<Expr> =
                    from_proto::parse_exprs(&join.left_join_key, ctx, extension_codec)?;
                let right_keys: Vec<Expr> =
                    from_proto::parse_exprs(&join.right_join_key, ctx, extension_codec)?;
                let join_type =
                    protobuf::JoinType::try_from(join.join_type).map_err(|_| {
                        proto_error(format!(
                            "Received a JoinNode message with unknown JoinType {}",
                            join.join_type
                        ))
                    })?;
                let join_constraint = protobuf::JoinConstraint::try_from(
                    join.join_constraint,
                )
                .map_err(|_| {
                    proto_error(format!(
                        "Received a JoinNode message with unknown JoinConstraint {}",
                        join.join_constraint
                    ))
                })?;
                let filter: Option<Expr> = join
                    .filter
                    .as_ref()
                    .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec))
                    .map_or(Ok(None), |v| v.map(Some))?;

                let builder = LogicalPlanBuilder::from(into_logical_plan!(
                    join.left,
                    ctx,
                    extension_codec
                )?);
                let builder = match join_constraint.into() {
                    JoinConstraint::On => builder.join_with_expr_keys(
                        into_logical_plan!(join.right, ctx, extension_codec)?,
                        join_type.into(),
                        (left_keys, right_keys),
                        filter,
                    )?,
                    JoinConstraint::Using => {
                        // The equijoin keys in using-join must be column.
                        let using_keys = left_keys
                            .into_iter()
                            .map(|key| {
                                key.try_as_col().cloned()
                                    .ok_or_else(|| internal_datafusion_err!(
                                        "Using join keys must be column references, got: {key:?}"
                                    ))
                            })
                            .collect::<Result<Vec<_>, _>>()?;
                        builder.join_using(
                            into_logical_plan!(join.right, ctx, extension_codec)?,
                            join_type.into(),
                            using_keys,
                        )?
                    }
                };

                builder.build()
            }
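            // Union: requires at least two inputs; the remaining inputs are folded
            // into the first one via the builder.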
            LogicalPlanType::Union(union) => {
                if union.inputs.len() < 2 {
                    return Err(DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, Union requires at least two inputs.",
                    )));
                }
                let (first, rest) = union.inputs.split_first().unwrap();
                let mut builder = LogicalPlanBuilder::from(
                    first.try_into_logical_plan(ctx, extension_codec)?,
                );

                for i in rest {
                    let plan = i.try_into_logical_plan(ctx, extension_codec)?;
                    builder = builder.union(plan)?;
                }
                builder.build()
            }
            LogicalPlanType::CrossJoin(crossjoin) => {
                let left = into_logical_plan!(crossjoin.left, ctx, extension_codec)?;
                let right = into_logical_plan!(crossjoin.right, ctx, extension_codec)?;

                LogicalPlanBuilder::from(left).cross_join(right)?.build()
            }
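            // Extension: user-defined nodes are decoded by the extension codec
            // together with their already-deserialized child plans.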
            LogicalPlanType::Extension(LogicalExtensionNode { node, inputs }) => {
                let input_plans: Vec<LogicalPlan> = inputs
                    .iter()
                    .map(|i| i.try_into_logical_plan(ctx, extension_codec))
                    .collect::<Result<_>>()?;

                let extension_node =
                    extension_codec.try_decode(node, &input_plans, ctx)?;
                Ok(LogicalPlan::Extension(extension_node))
            }
            LogicalPlanType::Distinct(distinct) => {
                let input: LogicalPlan =
                    into_logical_plan!(distinct.input, ctx, extension_codec)?;
                LogicalPlanBuilder::from(input).distinct()?.build()
            }
            LogicalPlanType::DistinctOn(distinct_on) => {
                let input: LogicalPlan =
                    into_logical_plan!(distinct_on.input, ctx, extension_codec)?;
                let on_expr =
                    from_proto::parse_exprs(&distinct_on.on_expr, ctx, extension_codec)?;
                let select_expr = from_proto::parse_exprs(
                    &distinct_on.select_expr,
                    ctx,
                    extension_codec,
                )?;
                let sort_expr = match distinct_on.sort_expr.len() {
                    0 => None,
                    _ => Some(from_proto::parse_sorts(
                        &distinct_on.sort_expr,
                        ctx,
                        extension_codec,
                    )?),
                };
                LogicalPlanBuilder::from(input)
                    .distinct_on(on_expr, select_expr, sort_expr)?
                    .build()
            }
            LogicalPlanType::ViewScan(scan) => {
                let schema: Schema = convert_required!(scan.schema)?;

                let mut projection = None;
                if let Some(columns) = &scan.projection {
                    let column_indices = columns
                        .columns
                        .iter()
                        .map(|name| schema.index_of(name))
                        .collect::<Result<Vec<usize>, _>>()?;
                    projection = Some(column_indices);
                }

                let input: LogicalPlan =
                    into_logical_plan!(scan.input, ctx, extension_codec)?;

                let definition = if !scan.definition.is_empty() {
                    Some(scan.definition.clone())
                } else {
                    None
                };

                let provider = ViewTable::new(input, definition);

                let table_name =
                    from_table_reference(scan.table_name.as_ref(), "ViewScan")?;

                LogicalPlanBuilder::scan(
                    table_name,
                    provider_as_source(Arc::new(provider)),
                    projection,
                )?
                .build()
            }
            LogicalPlanType::Prepare(prepare) => {
                let input: LogicalPlan =
                    into_logical_plan!(prepare.input, ctx, extension_codec)?;
                let data_types: Vec<DataType> = prepare
                    .data_types
                    .iter()
                    .map(DataType::try_from)
                    .collect::<Result<_, _>>()?;
                LogicalPlanBuilder::from(input)
                    .prepare(prepare.name.clone(), data_types)?
                    .build()
            }
            LogicalPlanType::DropView(dropview) => {
                Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView {
                    name: from_table_reference(dropview.name.as_ref(), "DropView")?,
                    if_exists: dropview.if_exists,
                    schema: Arc::new(convert_required!(dropview.schema)?),
                })))
            }
            LogicalPlanType::CopyTo(copy) => {
                let input: LogicalPlan =
                    into_logical_plan!(copy.input, ctx, extension_codec)?;

                let file_type: Arc<dyn FileType> = format_as_file_type(
                    extension_codec.try_decode_file_format(&copy.file_type, ctx)?,
                );

                Ok(LogicalPlan::Copy(dml::CopyTo {
                    input: Arc::new(input),
                    output_url: copy.output_url.clone(),
                    partition_by: copy.partition_by.clone(),
                    file_type,
                    options: Default::default(),
                }))
            }
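            // Unnest: the `Unnest` node is reconstructed directly, including the
            // output column and recursion depth for each unnested list column.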
            LogicalPlanType::Unnest(unnest) => {
                let input: LogicalPlan =
                    into_logical_plan!(unnest.input, ctx, extension_codec)?;
                Ok(LogicalPlan::Unnest(Unnest {
                    input: Arc::new(input),
                    exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(),
                    list_type_columns: unnest
                        .list_type_columns
                        .iter()
                        .map(|c| {
                            let recursion_item = c.recursion.as_ref().unwrap();
                            (
                                c.input_index as _,
                                ColumnUnnestList {
                                    output_column: recursion_item
                                        .output_column
                                        .as_ref()
                                        .unwrap()
                                        .into(),
                                    depth: recursion_item.depth as _,
                                },
                            )
                        })
                        .collect(),
                    struct_type_columns: unnest
                        .struct_type_columns
                        .iter()
                        .map(|c| *c as usize)
                        .collect(),
                    dependency_indices: unnest
                        .dependency_indices
                        .iter()
                        .map(|c| *c as usize)
                        .collect(),
                    schema: Arc::new(convert_required!(unnest.schema)?),
                    options: into_required!(unnest.options)?,
                }))
            }
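            // RecursiveQuery: both the static and recursive terms are serialized
            // plans themselves and must be present.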
            LogicalPlanType::RecursiveQuery(recursive_query_node) => {
                let static_term = recursive_query_node
                    .static_term
                    .as_ref()
                    .ok_or_else(|| DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, RecursiveQueryNode was missing required field static_term.",
                    )))?
                    .try_into_logical_plan(ctx, extension_codec)?;

                let recursive_term = recursive_query_node
                    .recursive_term
                    .as_ref()
                    .ok_or_else(|| DataFusionError::Internal(String::from(
                        "Protobuf deserialization error, RecursiveQueryNode was missing required field recursive_term.",
                    )))?
                    .try_into_logical_plan(ctx, extension_codec)?;

                Ok(LogicalPlan::RecursiveQuery(RecursiveQuery {
                    name: recursive_query_node.name.clone(),
                    static_term: Arc::new(static_term),
                    recursive_term: Arc::new(recursive_term),
                    is_distinct: recursive_query_node.is_distinct,
                }))
            }
            LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => {
                let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node;
                let schema = convert_required!(*schema)?;
                let cte_work_table = CteWorkTable::new(name.as_str(), Arc::new(schema));
                LogicalPlanBuilder::scan(
                    name.as_str(),
                    provider_as_source(Arc::new(cte_work_table)),
                    None,
                )?
                .build()
            }
            LogicalPlanType::Dml(dml_node) => Ok(LogicalPlan::Dml(
                datafusion::logical_expr::DmlStatement::new(
                    from_table_reference(dml_node.table_name.as_ref(), "DML ")?,
                    to_table_source(&dml_node.target, ctx, extension_codec)?,
                    dml_node.dml_type().into(),
                    Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?),
                ),
            )),
        }
    }
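
For orientation, here is a minimal round-trip sketch: it builds a small plan, encodes it into the protobuf LogicalPlanNode, and decodes it back through try_into_logical_plan. It assumes the AsLogicalPlan trait and DefaultLogicalExtensionCodec exported by the datafusion-proto crate; exact import paths and signatures may differ between versions.

    use datafusion::error::Result;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion::prelude::{lit, SessionContext};
    use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
    use datafusion_proto::protobuf::LogicalPlanNode;

    fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // Build a small plan that does not reference any tables.
        let plan = LogicalPlanBuilder::empty(true)
            .project(vec![lit(1).alias("x")])?
            .build()?;

        // Encode into the protobuf node, then decode it back into a LogicalPlan.
        // The default codec handles only built-in nodes and formats.
        let codec = DefaultLogicalExtensionCodec {};
        let proto = LogicalPlanNode::try_from_logical_plan(&plan, &codec)?;
        let round_tripped = proto.try_into_logical_plan(&ctx, &codec)?;

        assert_eq!(
            plan.display_indent().to_string(),
            round_tripped.display_indent().to_string()
        );
        Ok(())
    }

Plans that reference custom table providers or extension nodes need a matching LogicalExtensionCodec implementation in place of the default codec, since CustomScan and Extension variants delegate decoding to it.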