datafusion_ray/core.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import defaultdict
from dataclasses import dataclass
import logging
import os
import pyarrow as pa
import asyncio
import ray
import time
from .friendly import new_friendly_name
from datafusion_ray._datafusion_ray_internal import (
DFRayContext as DFRayContextInternal,
DFRayDataFrame as DFRayDataFrameInternal,
prettify,
)
def setup_logging():
import logging
logging.addLevelName(5, "TRACE")
log_level = os.environ.get("DATAFUSION_RAY_LOG_LEVEL", "WARN").upper()
# this logger gets captured and routed to rust. See src/lib.rs
logging.getLogger("core_py").setLevel(log_level)
logging.basicConfig()
setup_logging()
_log_level = os.environ.get("DATAFUSION_RAY_LOG_LEVEL", "ERROR").upper()
_rust_backtrace = os.environ.get("RUST_BACKTRACE", "0")
df_ray_runtime_env = {
"worker_process_setup_hook": setup_logging,
"env_vars": {
"DATAFUSION_RAY_LOG_LEVEL": _log_level,
"RAY_worker_niceness": "0",
"RUST_BACKTRACE": _rust_backtrace,
},
}
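# This runtime env can be passed to ray.init() so worker processes pick up the
# same logging setup and environment variables, e.g. (illustrative usage):
#
#   import ray
#   ray.init(runtime_env=df_ray_runtime_env)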
log = logging.getLogger("core_py")
def call_sync(coro):
"""call a coroutine in the current event loop or run a new one, and synchronously
return the result"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
else:
return loop.run_until_complete(coro)
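# For example, used below to synchronously wait on remote calls:
#
#   result = call_sync(wait_for([some_object_ref], "description"))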
# workaround for https://github.com/ray-project/ray/issues/31606
async def _ensure_coro(maybe_obj_ref):
return await maybe_obj_ref
async def wait_for(coros, name=""):
"""Wait for all coros to complete and return their results.
Does not preserve ordering."""
return_values = []
    # wrap the coro in a task so this works on both python 3.10 and 3.11+,
    # where asyncio.wait's semantics changed to no longer accept bare awaitables
start = time.time()
done, _ = await asyncio.wait(
[asyncio.create_task(_ensure_coro(c)) for c in coros]
)
end = time.time()
log.info(f"waiting for {name} took {end - start}s")
for d in done:
e = d.exception()
if e is not None:
log.error(f"Exception waiting {name}: {e}")
raise e
else:
return_values.append(d.result())
return return_values
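# Illustrative use (actor and method names are hypothetical): awaiting a batch
# of Ray ObjectRefs or coroutines at once,
#
#   results = await wait_for([a.some_method.remote() for a in actors], "some_method calls")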
class DFRayProcessorPool:
"""A pool of DFRayProcessor actors that can be acquired and released"""
# TODO: We can probably manage this set in a better way
# This is not a threadsafe implementation, though the DFRayContextSupervisor accesses it
# from a single thread
#
# This is simple though and will suffice for now
def __init__(self, min_processors: int, max_processors: int):
self.min_processors = min_processors
self.max_processors = max_processors
# a map of processor_key (a random identifier) to stage actor reference
self.pool = {}
# a map of processor_key to listening address
self.addrs = {}
# holds object references from the start_up method for each processor
# we know all processors are listening when all of these refs have
# been waited on. When they are ready we remove them from this set
self.processors_started = set()
# an event that is set when all processors are ready to serve
self.processors_ready = asyncio.Event()
# processors that are started but we need to get their address
self.need_address = set()
# processors that we have the address for but need to start serving
self.need_serving = set()
# processors in use
self.acquired = set()
# processors available
self.available = set()
for _ in range(min_processors):
self._new_processor()
log.info(
f"created ray processor pool (min_processors: {min_processors}, max_processors: {max_processors})"
)
async def start(self):
if not self.processors_ready.is_set():
await self._wait_for_processors_started()
await self._wait_for_get_addrs()
await self._wait_for_serve()
self.processors_ready.set()
async def wait_for_ready(self):
await self.processors_ready.wait()
async def acquire(self, need=1):
processor_keys = []
have = len(self.available)
total = len(self.available) + len(self.acquired)
can_make = self.max_processors - total
need_to_make = need - have
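        # e.g. (illustrative numbers): with max_processors=10, 2 available and
        # 3 acquired, acquire(5) gives can_make=5 and need_to_make=3, so 3 new
        # processors are created before 5 are handed out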
if need_to_make > can_make:
raise Exception(
f"Cannot allocate processors above {self.max_processors}"
)
if need_to_make > 0:
log.debug(f"creating {need_to_make} additional processors")
for _ in range(need_to_make):
self._new_processor()
await wait_for([self.start()], "waiting for created processors")
assert len(self.available) >= need
for _ in range(need):
processor_key = self.available.pop()
self.acquired.add(processor_key)
processor_keys.append(processor_key)
processors = [self.pool[sk] for sk in processor_keys]
addrs = [self.addrs[sk] for sk in processor_keys]
return (processors, processor_keys, addrs)
def release(self, processor_keys: list[str]):
for processor_key in processor_keys:
self.acquired.remove(processor_key)
self.available.add(processor_key)
def _new_processor(self):
self.processors_ready.clear()
processor_key = new_friendly_name()
log.debug(f"starting processor: {processor_key}")
processor = DFRayProcessor.options(
name=f"Processor : {processor_key}"
).remote(processor_key)
self.pool[processor_key] = processor
self.processors_started.add(processor.start_up.remote())
self.available.add(processor_key)
async def _wait_for_processors_started(self):
log.info("waiting for processors to be ready")
started_keys = await wait_for(
self.processors_started, "processors to be started"
)
# we need the addresses of these processors still
self.need_address.update(set(started_keys))
# we've started all the processors we know about
self.processors_started = set()
log.info("processors are all listening")
async def _wait_for_get_addrs(self):
# get the addresses in a pipelined fashion
refs = []
processor_keys = []
for processor_key in self.need_address:
processor = self.pool[processor_key]
refs.append(processor.addr.remote())
processor_keys.append(processor_key)
self.need_serving.add(processor_key)
addrs = await wait_for(refs, "processor addresses")
for key, addr in addrs:
self.addrs[key] = addr
self.need_address = set()
async def _wait_for_serve(self):
log.info("running processors")
try:
for processor_key in self.need_serving:
log.info(f"starting serving of processor {processor_key}")
processor = self.pool[processor_key]
processor.serve.remote()
self.need_serving = set()
except Exception as e:
log.error(f"ProcessorPool: Uhandled Exception in serve: {e}")
raise e
async def all_done(self):
log.info("calling processor all done")
refs = [
processor.all_done.remote() for processor in self.pool.values()
]
await wait_for(refs, "processors to be all done")
log.info("all processors shutdown")
@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
class DFRayProcessor:
def __init__(self, processor_key):
self.processor_key = processor_key
# import this here so ray doesn't try to serialize the rust extension
from datafusion_ray._datafusion_ray_internal import (
DFRayProcessorService,
)
self.processor_service = DFRayProcessorService(processor_key)
async def start_up(self):
# this method is sync
self.processor_service.start_up()
return self.processor_key
async def all_done(self):
await self.processor_service.all_done()
async def addr(self):
return (self.processor_key, self.processor_service.addr())
async def update_plan(
self,
stage_id: int,
stage_addrs: dict[int, dict[int, list[str]]],
partition_group: list[int],
plan_bytes: bytes,
):
await self.processor_service.update_plan(
stage_id,
stage_addrs,
partition_group,
plan_bytes,
)
async def serve(self):
log.info(
f"[{self.processor_key}] serving on {self.processor_service.addr()}"
)
await self.processor_service.serve()
log.info(f"[{self.processor_key}] done serving")
@dataclass
class StageData:
stage_id: int
plan_bytes: bytes
partition_group: list[int]
child_stage_ids: list[int]
num_output_partitions: int
full_partitions: bool
@dataclass
class InternalStageData:
stage_id: int
plan_bytes: bytes
partition_group: list[int]
child_stage_ids: list[int]
num_output_partitions: int
full_partitions: bool
remote_processor: ... # ray.actor.ActorHandle[DFRayProcessor]
remote_addr: str
def __str__(self):
return f"""Stage: {self.stage_id}, pg: {self.partition_group}, child_stages:{self.child_stage_ids}, listening addr:{self.remote_addr}"""
@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
class DFRayContextSupervisor:
def __init__(
self,
processor_pool_min: int,
processor_pool_max: int,
) -> None:
log.info(
f"Creating DFRayContextSupervisor processor_pool_min: {processor_pool_min}"
)
self.pool = DFRayProcessorPool(processor_pool_min, processor_pool_max)
self.stages: dict[str, InternalStageData] = {}
log.info("Created DFRayContextSupervisor")
async def start(self):
await self.pool.start()
async def wait_for_ready(self):
await self.pool.wait_for_ready()
async def get_stage_addrs(self, stage_id: int):
addrs = [
sd.remote_addr
for sd in self.stages.values()
if sd.stage_id == stage_id
]
return addrs
async def new_query(
self,
stage_datas: list[StageData],
):
if len(self.stages) > 0:
self.pool.release(list(self.stages.keys()))
remote_processors, remote_processor_keys, remote_addrs = (
await self.pool.acquire(len(stage_datas))
)
self.stages = {}
for i, sd in enumerate(stage_datas):
remote_processor = remote_processors[i]
remote_processor_key = remote_processor_keys[i]
remote_addr = remote_addrs[i]
self.stages[remote_processor_key] = InternalStageData(
sd.stage_id,
sd.plan_bytes,
sd.partition_group,
sd.child_stage_ids,
sd.num_output_partitions,
sd.full_partitions,
remote_processor,
remote_addr,
)
# sort out the mess of who talks to whom and ensure we can supply the correct
# addresses to each of them
addrs_by_stage_key = await self.sort_out_addresses()
if log.level <= logging.DEBUG:
# TODO: string builder here
out = ""
for stage_key, stage in self.stages.items():
out += f"[{stage_key}]: {stage}\n"
out += f"child addrs: {addrs_by_stage_key[stage_key]}\n"
log.debug(out)
refs = []
# now tell the stages what they are doing for this query
for stage_key, isd in self.stages.items():
log.info(f"going to update plan for {stage_key}")
kid = addrs_by_stage_key[stage_key]
refs.append(
isd.remote_processor.update_plan.remote(
isd.stage_id,
{
stage_id: val["child_addrs"]
for (stage_id, val) in kid.items()
},
isd.partition_group,
isd.plan_bytes,
)
)
log.info("that's all of them")
await wait_for(refs, "updating plans")
async def sort_out_addresses(self):
"""Iterate through our stages and gather all of their listening addresses.
Then, provide the addresses to of peer stages to each stage.
"""
addrs_by_stage_key = {}
for stage_key, isd in self.stages.items():
stage_addrs = defaultdict(dict)
# using "isd" as shorthand to denote InternalStageData as a reminder
for child_stage_id in isd.child_stage_ids:
addrs = defaultdict(list)
child_stage_keys, child_stage_datas = zip(
*filter(
lambda x: x[1].stage_id == child_stage_id,
self.stages.items(),
)
)
output_partitions = [
c_isd.num_output_partitions for c_isd in child_stage_datas
]
# sanity check
assert all(
[op == output_partitions[0] for op in output_partitions]
)
output_partitions = output_partitions[0]
for child_stage_isd in child_stage_datas:
if child_stage_isd.full_partitions:
for partition in range(output_partitions):
# this stage is the definitive place to read this output partition
addrs[partition] = [child_stage_isd.remote_addr]
else:
for partition in range(output_partitions):
# this output partition must be gathered from all stages with this stage_id
addrs[partition] = [
c.remote_addr for c in child_stage_datas
]
stage_addrs[child_stage_id]["child_addrs"] = addrs
# not necessary but useful for debug logs
stage_addrs[child_stage_id]["stage_keys"] = child_stage_keys
addrs_by_stage_key[stage_key] = stage_addrs
return addrs_by_stage_key
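    # The mapping returned by sort_out_addresses looks roughly like this
    # (keys and addresses are illustrative):
    #
    #   {
    #       "<stage_key>": {
    #           <child_stage_id>: {
    #               "child_addrs": {0: ["addr0", ...], 1: ["addr1", ...]},
    #               "stage_keys": ("<child_stage_key>", ...),
    #           },
    #       },
    #   }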
async def all_done(self):
await self.pool.all_done()
class DFRayDataFrame:
def __init__(
self,
internal_df: DFRayDataFrameInternal,
supervisor, # ray.actor.ActorHandle[DFRayContextSupervisor],
batch_size=8192,
partitions_per_processor: int | None = None,
prefetch_buffer_size=0,
):
self.df = internal_df
self.supervisor = supervisor
self._stages = None
self._batches = None
self.batch_size = batch_size
self.partitions_per_processor = partitions_per_processor
self.prefetch_buffer_size = prefetch_buffer_size
def stages(self):
        # lazily create our stages, which we need before we can execute the query
if not self._stages:
self._stages = self.df.stages(
self.batch_size,
self.prefetch_buffer_size,
self.partitions_per_processor,
)
return self._stages
def schema(self):
return self.df.schema()
def execution_plan(self):
return self.df.execution_plan()
def logical_plan(self):
return self.df.logical_plan()
def optimized_logical_plan(self):
return self.df.optimized_logical_plan()
def collect(self) -> list[pa.RecordBatch]:
if not self._batches:
t1 = time.time()
self.stages()
t2 = time.time()
log.debug(f"creating stages took {t2 - t1}s")
last_stage_id = max([stage.stage_id for stage in self._stages])
log.debug(f"last stage is {last_stage_id}")
self.create_ray_stages()
last_stage_addrs = ray.get(
self.supervisor.get_stage_addrs.remote(last_stage_id)
)
log.debug(f"last stage addrs {last_stage_addrs}")
reader = self.df.read_final_stage(
last_stage_id, last_stage_addrs[0]
)
log.debug("got reader")
self._batches = list(reader)
return self._batches
def show(self) -> None:
batches = self.collect()
print(prettify(batches))
def create_ray_stages(self):
stage_datas = []
        # note: each PyDataFrameStage held in self.stages() describes a single
        # numbered stage, but the supervisor wants one StageData per actor that
        # will be created, hence the loop over partition_groups
for stage in self.stages():
for partition_group in stage.partition_groups:
stage_datas.append(
StageData(
stage.stage_id,
stage.plan_bytes(),
partition_group,
stage.child_stage_ids,
stage.num_output_partitions,
stage.full_partitions,
)
)
ref = self.supervisor.new_query.remote(stage_datas)
call_sync(wait_for([ref], "creating ray stages"))
class DFRayContext:
def __init__(
self,
batch_size: int = 8192,
prefetch_buffer_size: int = 0,
partitions_per_processor: int | None = None,
processor_pool_min: int = 1,
processor_pool_max: int = 100,
) -> None:
self.ctx = DFRayContextInternal()
self.batch_size = batch_size
self.partitions_per_processor = partitions_per_processor
self.prefetch_buffer_size = prefetch_buffer_size
self.supervisor = DFRayContextSupervisor.options(
name="RayContextSupersisor",
).remote(
processor_pool_min,
processor_pool_max,
)
        # start up our supervisor; below we wait for it to be ready before
        # accepting queries
start_ref = self.supervisor.start.remote()
# ensure we are ready
s = time.time()
call_sync(wait_for([start_ref], "RayContextSupervisor start"))
e = time.time()
log.info(
f"RayContext::__init__ waiting for supervisor to be ready took {e - s}s"
)
def register_parquet(self, name: str, path: str):
"""
Register a Parquet file with the given name and path.
The path can be a local filesystem path, absolute filesystem path, or a url.
If the path is a object store url, the appropriate object store will be registered.
Configuration of the object store will be gathered from the environment.
For example for s3:// urls, credentials will be looked for by the AWS SDK,
which will check environment variables, credential files, etc
Parameters:
path (str): The file path to the Parquet file.
name (str): The name to register the Parquet file under.
"""
self.ctx.register_parquet(name, path)
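        # e.g., for a DFRayContext instance named ctx (name and path are illustrative):
        #   ctx.register_parquet("tips", "s3://my-bucket/tips.parquet")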
def register_csv(self, name: str, path: str):
"""
Register a csvfile with the given name and path.
The path can be a local filesystem path, absolute filesystem path, or a url.
If the path is a object store url, the appropriate object store will be registered.
Configuration of the object store will be gathered from the environment.
For example for s3:// urls, credentials will be looked for by the AWS SDK,
which will check environment variables, credential files, etc
Parameters:
path (str): The file path to the csv file.
name (str): The name to register the Parquet file under.
"""
self.ctx.register_csv(name, path)
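        # e.g., for a DFRayContext instance named ctx (name and path are illustrative):
        #   ctx.register_csv("tips", "/data/tips.csv")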
    def register_listing_table(
        self, name: str, path: str, file_extension="parquet"
    ):
"""
Register a directory of parquet files with the given name.
The path can be a local filesystem path, absolute filesystem path, or a url.
If the path is a object store url, the appropriate object store will be registered.
Configuration of the object store will be gathered from the environment.
For example for s3:// urls, credentials will be looked for by the AWS SDK,
which will check environment variables, credential files, etc
Parameters:
path (str): The file path to the Parquet file directory
name (str): The name to register the Parquet file under.
"""
        self.ctx.register_listing_table(name, path, file_extension)
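        # e.g., for a DFRayContext instance named ctx (name and path are illustrative):
        #   ctx.register_listing_table("events", "s3://my-bucket/events/")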
def sql(self, query: str) -> DFRayDataFrame:
df = self.ctx.sql(query)
return DFRayDataFrame(
df,
self.supervisor,
self.batch_size,
self.partitions_per_processor,
self.prefetch_buffer_size,
)
def set(self, option: str, value: str) -> None:
self.ctx.set(option, value)
def __del__(self):
log.info("DFRayContext, cleaning up remote resources")
ref = self.supervisor.all_done.remote()
call_sync(wait_for([ref], "DFRayContextSupervisor all done"))
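# A minimal end-to-end sketch (assumes a reachable Ray cluster and an existing
# Parquet file; the table name and path below are illustrative):
#
#   import ray
#   ray.init(runtime_env=df_ray_runtime_env)
#   ctx = DFRayContext(batch_size=8192, processor_pool_min=2)
#   ctx.register_parquet("trips", "./data/trips.parquet")
#   df = ctx.sql("SELECT count(*) FROM trips")
#   df.show()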