src/stage_reader.rs (148 lines of code) (raw):

use std::{fmt::Formatter, sync::Arc}; use arrow_flight::{FlightClient, Ticket}; use datafusion::common::{internal_datafusion_err, internal_err}; use datafusion::error::Result; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, }; use datafusion::{arrow::datatypes::SchemaRef, execution::SendableRecordBatchStream}; use futures::stream::TryStreamExt; use futures::StreamExt; use log::trace; use prost::Message; use crate::processor_service::ServiceClients; use crate::protobuf::FlightTicketData; use crate::util::CombinedRecordBatchStream; /// An [`ExecutionPlan`] that will produce a stream of batches fetched from another stage /// which is hosted by a [`crate::stage_service::StageService`] separated from a network boundary /// /// Note that discovery of the service is handled by populating an instance of [`crate::stage_service::ServiceClients`] /// and storing it as an extension in the [`datafusion::execution::TaskContext`] configuration. #[derive(Debug)] pub struct DFRayStageReaderExec { properties: PlanProperties, schema: SchemaRef, pub stage_id: usize, } impl DFRayStageReaderExec { pub fn try_new_from_input(input: Arc<dyn ExecutionPlan>, stage_id: usize) -> Result<Self> { let properties = input.properties().clone(); Self::try_new(properties.partitioning.clone(), input.schema(), stage_id) } pub fn try_new(partitioning: Partitioning, schema: SchemaRef, stage_id: usize) -> Result<Self> { let properties = PlanProperties::new( EquivalenceProperties::new(schema.clone()), Partitioning::UnknownPartitioning(partitioning.partition_count()), EmissionType::Incremental, Boundedness::Bounded, ); Ok(Self { properties, schema, stage_id, }) } } impl DisplayAs for DFRayStageReaderExec { fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { write!( f, "RayStageReaderExec[{}] (output_partitioning={:?})", self.stage_id, self.properties().partitioning ) } } impl ExecutionPlan for DFRayStageReaderExec { fn schema(&self) -> SchemaRef { self.schema.clone() } fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { vec![] } fn name(&self) -> &str { "RayStageReaderExec" } fn as_any(&self) -> &dyn std::any::Any { self } fn properties(&self) -> &datafusion::physical_plan::PlanProperties { &self.properties } fn with_new_children( self: std::sync::Arc<Self>, _children: Vec<std::sync::Arc<dyn ExecutionPlan>>, ) -> datafusion::error::Result<std::sync::Arc<dyn ExecutionPlan>> { // TODO: handle more general case unimplemented!() } fn execute( &self, partition: usize, context: std::sync::Arc<datafusion::execution::TaskContext>, ) -> Result<SendableRecordBatchStream> { let name = format!("RayStageReaderExec[{}-{}]:", self.stage_id, partition); trace!("{name} execute"); let client_map = &context .session_config() .get_extension::<ServiceClients>() .ok_or(internal_datafusion_err!( "{name} Flight Client not in context" ))? .clone() .0; trace!("{name} client_map keys {:?}", client_map.keys()); let clients = client_map .get(&(self.stage_id, partition)) .ok_or(internal_datafusion_err!( "{} No flight clients found for {}:{}, have {:?}", name, self.stage_id, partition, client_map.keys() ))? .lock() .iter() .map(|c| { let inner_clone = c.inner().clone(); FlightClient::new_from_inner(inner_clone) }) .collect::<Vec<_>>(); let ftd = FlightTicketData { dummy: false, partition: partition as u64, }; let ticket = Ticket { ticket: ftd.encode_to_vec().into(), }; let schema = self.schema.clone(); let stream = async_stream::stream! { let mut error = false; let mut streams = vec![]; for mut client in clients { let name = name.clone(); trace!("{name} Getting flight stream" ); match client.do_get(ticket.clone()).await { Ok(flight_stream) => { trace!("{name} Got flight stream. headers:{:?}", flight_stream.headers()); let rbr_stream = RecordBatchStreamAdapter::new(schema.clone(), flight_stream .map_err(move |e| internal_datafusion_err!("{} Error consuming flight stream: {}", name, e))); streams.push(Box::pin(rbr_stream) as SendableRecordBatchStream); }, Err(e) => { error = true; yield internal_err!("{} Error getting flight stream: {}", name, e); } } } if !error { let mut combined = CombinedRecordBatchStream::new(schema.clone(),streams); while let Some(maybe_batch) = combined.next().await { yield maybe_batch; } } }; Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema.clone(), stream, ))) } }