in runners/datafusion-comet/tpcbench.py [0:0]
def _benchmark_spec(benchmark: str) -> tuple:
    """Return ``(num_queries, table_names)`` for a supported benchmark.

    Args:
        benchmark: ``"tpch"`` or ``"tpcds"``.

    Returns:
        A 2-tuple of the number of queries (queries are numbered 1..N) and
        the list of table names to register.

    Raises:
        ValueError: if ``benchmark`` is not a supported benchmark name.
    """
    if benchmark == "tpch":
        return 22, ["customer", "lineitem", "nation", "orders", "part",
                    "partsupp", "region", "supplier"]
    if benchmark == "tpcds":
        return 99, ["call_center", "catalog_page", "catalog_returns", "catalog_sales", "customer",
                    "customer_address", "customer_demographics", "date_dim", "time_dim",
                    "household_demographics", "income_band", "inventory", "item", "promotion",
                    "reason", "ship_mode", "store", "store_returns", "store_sales", "warehouse",
                    "web_page", "web_returns", "web_sales", "web_site"]
    # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
    raise ValueError(f"invalid benchmark: {benchmark!r}")


def _register_tables(spark, data_path: str, table_names) -> None:
    """Register ``<data_path>/<table>.parquet`` as a temp view named ``<table>`` for each table."""
    for table in table_names:
        path = f"{data_path}/{table}.parquet"
        print(f"Registering table {table} using path {path}")
        spark.read.parquet(path).createOrReplaceTempView(table)


def _query_file(query_path: str, query: int) -> str:
    """Return the path of the SQL file for query number ``query``."""
    if query == 72:
        # use version with sensible join order
        return f"{query_path}/q{query}_optimized.sql"
    return f"{query_path}/q{query}.sql"


def _run_query(spark, query_path: str, query: int) -> float:
    """Execute every statement in a query's SQL file and return the elapsed seconds.

    Each statement's result is fully collected so the measured time includes
    execution, not just planning. File reading is excluded from the timing.
    """
    path = _query_file(query_path, query)
    print(f"Reading query {query} using path {path}")
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    # Each file can contain multiple statements. NOTE: this is a naive split,
    # so a ';' inside a SQL string literal would break a statement apart.
    statements = text.split(";")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    for sql in statements:
        # Views created by query files must not outlive the session (case-sensitive match).
        sql = sql.strip().replace("create view", "create temp view")
        if sql:
            print(f"Executing: {sql}")
            rows = spark.sql(sql).collect()
            print(f"Query {query} returned {len(rows)} rows")
    elapsed = time.perf_counter() - start_time
    print(f"Query {query} took {elapsed} seconds")
    return elapsed


def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str):
    """Run a TPC-H / TPC-DS derived benchmark on Spark and write timings as JSON.

    Args:
        benchmark: ``"tpch"`` or ``"tpcds"``.
        data_path: Directory containing one ``<table>.parquet`` per table.
        query_path: Directory containing ``q<N>.sql`` files (``q72_optimized.sql`` for query 72).
        iterations: Number of times to run the full query set.
        output: Directory where ``spark-<benchmark>-<millis>.json`` is written.

    Raises:
        ValueError: if ``benchmark`` is not supported (checked before Spark starts).
    """
    # Validate up front so a bad benchmark name fails without starting Spark.
    num_queries, table_names = _benchmark_spec(benchmark)

    # Initialize a SparkSession
    spark = SparkSession.builder \
        .appName("DataFusion Comet Benchmark derived from TPC-H / TPC-DS") \
        .getOrCreate()
    try:
        _register_tables(spark, data_path, table_names)

        conf_dict = dict(spark.sparkContext.getConf().getAll())
        results = {
            'engine': 'datafusion-comet',
            'benchmark': benchmark,
            'data_path': data_path,
            'query_path': query_path,
            'spark_conf': conf_dict,
        }

        for iteration in range(iterations):
            print(f"Starting iteration {iteration + 1} of {iterations}")
            for query in range(1, num_queries + 1):
                spark.sparkContext.setJobDescription(f"{benchmark} q{query}")
                elapsed = _run_query(spark, query_path, query)
                # One timing per iteration, keyed by query number.
                results.setdefault(query, []).append(elapsed)

        results_json = json.dumps(results, indent=4)
        current_time_millis = int(datetime.now().timestamp() * 1000)
        results_path = f"{output}/spark-{benchmark}-{current_time_millis}.json"
        print(f"Writing results to {results_path}")
        with open(results_path, "w", encoding="utf-8") as f:
            f.write(results_json)
    finally:
        # Stop the SparkSession even if a query or the results write fails.
        spark.stop()