# tpch/tpcbench.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import json
import os
import sys
import time
from datetime import datetime

import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env
from datafusion_ray.util import LocalValidator, prettify
def tpch_query(qnum: int) -> str:
    """Return the SQL text for TPC-H query ``qnum``.

    The query is read from ``queries/q{qnum}.sql`` located next to this file.

    Args:
        qnum: TPC-H query number (expected 1-22).

    Returns:
        The raw contents of the query file; may contain multiple
        ``;``-separated statements.

    Raises:
        FileNotFoundError: if the query file does not exist.
    """
    query_path = os.path.join(os.path.dirname(__file__), "queries")
    # Use a context manager so the file handle is closed deterministically
    # instead of relying on garbage collection.
    with open(os.path.join(query_path, f"q{qnum}.sql")) as f:
        return f.read()
def main(
    qnum: int,
    data_path: str,
    concurrency: int,
    batch_size: int,
    partitions_per_processor: int | None,
    processor_pool_min: int,
    listing_tables: bool,
    validate: bool,
    output_path: str,
    prefetch_buffer_size: int,
):
    """Run the TPC-H benchmark on a DataFusion-Ray cluster.

    Registers the eight TPC-H tables, executes either a single query or all
    22 queries, writes per-query timings to a JSON file under
    ``output_path`` (updated incrementally after each query), and optionally
    validates distributed results against a local DataFusion context.

    Args:
        qnum: TPC-H query number to run, or -1 to run all 22 queries.
        data_path: Directory containing ``<table>.parquet`` inputs.
        concurrency: DataFusion target partition count.
        batch_size: Desired record-batch size produced per stage.
        partitions_per_processor: Partitions handled by each DFRayProcessor,
            or None for the context's default.
        processor_pool_min: Minimum number of DFRayProcessors kept pooled.
        listing_tables: If True, register tables as listing tables rather
            than single parquet files.
        validate: If True, re-run each query locally and compare output.
        output_path: Directory to receive the results JSON and per-query
            result text files.
        prefetch_buffer_size: Batches each stage buffers eagerly.

    Raises:
        SystemExit: with code 1 when validation was requested and at least
            one query's distributed result did not match the local result.
    """
    # The eight TPC-H relations; each must exist as `<name>.parquet`
    # (or a listing-table directory) under data_path.
    table_names = [
        "customer",
        "lineitem",
        "nation",
        "orders",
        "part",
        "partsupp",
        "region",
        "supplier",
    ]

    # Connect to the cluster (launch this script with `ray job submit`).
    ray.init(runtime_env=df_ray_runtime_env)

    ctx = DFRayContext(
        batch_size=batch_size,
        partitions_per_processor=partitions_per_processor,
        prefetch_buffer_size=prefetch_buffer_size,
        processor_pool_min=processor_pool_min,
        processor_pool_max=1000,
    )
    # Local, non-distributed DataFusion context used to cross-check results.
    local = LocalValidator()

    ctx.set("datafusion.execution.target_partitions", f"{concurrency}")
    # ctx.set("datafusion.execution.parquet.pushdown_filters", "true")
    ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")
    ctx.set("datafusion.execution.coalesce_batches", "false")

    # Register every table on both the distributed and local contexts so
    # validation runs against identical inputs.
    for table in table_names:
        path = os.path.join(data_path, f"{table}.parquet")
        print(f"Registering table {table} using path {path}")
        if listing_tables:
            ctx.register_listing_table(table, path)
            local.register_listing_table(table, path)
        else:
            ctx.register_parquet(table, path)
            local.register_parquet(table, path)

    # Millisecond timestamp makes each benchmark run's output file unique.
    current_time_millis = int(datetime.now().timestamp() * 1000)
    results_path = os.path.join(
        output_path, f"datafusion-ray-tpch-{current_time_millis}.json"
    )
    print(f"Writing results to {results_path}")

    results = {
        "engine": "datafusion-ray",
        "benchmark": "tpch",
        "settings": {
            "concurrency": concurrency,
            "batch_size": batch_size,
            "prefetch_buffer_size": prefetch_buffer_size,
            "partitions_per_processor": partitions_per_processor,
        },
        "data_path": data_path,
        "queries": {},
    }
    if validate:
        results["validated"] = {}

    # qnum == -1 means "run all 22 TPC-H queries".
    queries = range(1, 23) if qnum == -1 else [qnum]

    for qnum in queries:
        sql = tpch_query(qnum)
        # A query file may hold several `;`-separated statements; strip
        # whitespace and drop empty fragments from trailing semicolons.
        statements = [
            stmt for stmt in (part.strip() for part in sql.split(";")) if stmt
        ]

        start_time = time.time()
        all_batches = []
        for statement in statements:
            print("executing ", statement)
            df = ctx.sql(statement)
            all_batches.append(df.collect())
        end_time = time.time()
        # Wall-clock seconds for all statements of this query.
        results["queries"][qnum] = end_time - start_time

        calculated = "\n".join([prettify(b) for b in all_batches])
        print(calculated)

        out_path = os.path.join(
            output_path, f"datafusion_ray_tpch_q{qnum}_result.txt"
        )
        with open(out_path, "w") as f:
            f.write(calculated)

        if validate:
            # Re-run the same statements on the local context and compare
            # the pretty-printed output text exactly.
            all_batches = [local.collect_sql(statement) for statement in statements]
            expected = "\n".join([prettify(b) for b in all_batches])
            results["validated"][qnum] = calculated == expected

        print(f"done with query {qnum}")

        # write the results as we go, so you can peek at them
        results_dump = json.dumps(results, indent=4)
        with open(results_path, "w+") as f:
            f.write(results_dump)
        # write results to stdout
        print(results_dump)

    # give ray a moment to clean up
    print("benchmark complete. sleeping for 3 seconds for ray to clean up")
    time.sleep(3)

    if validate and False in results["validated"].values():
        # Return a non-zero exit code if any query failed validation.
        # sys.exit (not the site-module `exit` builtin) is the correct way
        # to terminate a script with a status code.
        print("Possible incorrect query result")
        sys.exit(1)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="DataFusion benchmark derived from TPC-H / TPC-DS"
    )
    parser.add_argument("--data", required=True, help="Path to data files")
    parser.add_argument(
        "--concurrency",
        required=True,
        type=int,  # let argparse validate and convert, with a proper CLI error
        help="Number of concurrent tasks",
    )
    parser.add_argument(
        "--qnum", type=int, default=-1, help="TPCH query number, 1-22"
    )
    parser.add_argument("--listing-tables", action="store_true")
    parser.add_argument("--validate", action="store_true")
    # NOTE(review): --log-level is accepted but not currently consumed by
    # main(); kept for CLI compatibility.
    parser.add_argument(
        "--log-level", default="INFO", help="ERROR,WARN,INFO,DEBUG,TRACE"
    )
    parser.add_argument(
        "--batch-size",
        required=False,
        type=int,  # converted here instead of via int() at the call site
        default=8192,
        help="Desired batch size output per stage",
    )
    parser.add_argument(
        "--partitions-per-processor",
        type=int,
        help="partitions per DFRayProcessor",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default=".",
        help="directory to write output json",
    )
    parser.add_argument(
        "--prefetch-buffer-size",
        required=False,
        default=0,
        type=int,
        help="How many batches each stage should eagerly buffer",
    )
    parser.add_argument(
        "--processor-pool-min",
        type=int,
        help="Minimum number of DFRayProcessors to keep in pool",
    )
    args = parser.parse_args()

    main(
        args.qnum,
        args.data,
        args.concurrency,
        args.batch_size,
        args.partitions_per_processor,
        args.processor_pool_min,
        args.listing_tables,
        args.validate,
        args.output_path,
        args.prefetch_buffer_size,
    )