runners/datafusion-comet/tpcbench.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
from datetime import datetime
import json
import time

from pyspark.sql import SparkSession


def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str):
    # Initialize a SparkSession
    spark = SparkSession.builder \
        .appName("DataFusion Comet Benchmark derived from TPC-H / TPC-DS") \
        .getOrCreate()

    # Register the tables
    if benchmark == "tpch":
        num_queries = 22
        table_names = ["customer", "lineitem", "nation", "orders", "part",
                       "partsupp", "region", "supplier"]
    elif benchmark == "tpcds":
        num_queries = 99
        table_names = ["call_center", "catalog_page", "catalog_returns", "catalog_sales",
                       "customer", "customer_address", "customer_demographics", "date_dim",
                       "time_dim", "household_demographics", "income_band", "inventory",
                       "item", "promotion", "reason", "ship_mode", "store", "store_returns",
                       "store_sales", "warehouse", "web_page", "web_returns", "web_sales",
                       "web_site"]
    else:
        raise ValueError(f"Invalid benchmark: {benchmark}")

    for table in table_names:
        path = f"{data_path}/{table}.parquet"
        print(f"Registering table {table} using path {path}")
        df = spark.read.parquet(path)
        df.createOrReplaceTempView(table)

    conf_dict = {k: v for k, v in spark.sparkContext.getConf().getAll()}

    results = {
        'engine': 'datafusion-comet',
        'benchmark': benchmark,
        'data_path': data_path,
        'query_path': query_path,
        'spark_conf': conf_dict,
    }

    for iteration in range(0, iterations):
        print(f"Starting iteration {iteration} of {iterations}")
        for query in range(1, num_queries + 1):
            spark.sparkContext.setJobDescription(f"{benchmark} q{query}")

            # read the query text file
            if query == 72:
                # use version with sensible join order
                path = f"{query_path}/q{query}_optimized.sql"
            else:
                path = f"{query_path}/q{query}.sql"

            print(f"Reading query {query} using path {path}")
            with open(path, "r") as f:
                text = f.read()

            # each file can contain multiple queries
            queries = text.split(";")

            start_time = time.time()
            for sql in queries:
                sql = sql.strip().replace("create view", "create temp view")
                if len(sql) > 0:
                    print(f"Executing: {sql}")
                    df = spark.sql(sql)
                    rows = df.collect()
                    print(f"Query {query} returned {len(rows)} rows")
            end_time = time.time()
            print(f"Query {query} took {end_time - start_time} seconds")

            # append this iteration's elapsed time to the per-query list of timings
            query_timings = results.setdefault(query, [])
            query_timings.append(end_time - start_time)

    results_json = json.dumps(results, indent=4)
    current_time_millis = int(datetime.now().timestamp() * 1000)
    results_path = f"{output}/spark-{benchmark}-{current_time_millis}.json"
    print(f"Writing results to {results_path}")
    with open(results_path, "w") as f:
        f.write(results_json)

    # Stop the SparkSession
    spark.stop()
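
# For reference, the results file written above has roughly this shape (a sketch
# inferred from the code; the values shown are illustrative placeholders only):
#
#   {
#       "engine": "datafusion-comet",
#       "benchmark": "tpch",
#       "data_path": "...",
#       "query_path": "...",
#       "spark_conf": { ... },
#       "1": [1.23],
#       "2": [0.45],
#       ...
#   }
#
# Each numeric key is a query number and its list holds one elapsed time
# (in seconds) per iteration.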
parser.add_argument("--benchmark", required=True, help="Benchmark to run (tpch or tpcds)") parser.add_argument("--data", required=True, help="Path to data files") parser.add_argument("--queries", required=True, help="Path to query files") parser.add_argument("--iterations", required=False, default="1", help="How many iterations to run") parser.add_argument("--output", required=True, help="Path to write output") args = parser.parse_args() main(args.benchmark, args.data, args.queries, int(args.iterations), args.output)