k8s/spark_tpcbench.py (90 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import argparse from datetime import datetime import json from pyspark.sql import SparkSession import time import sys def main(benchmark: str, data_path: str, query_path: str, output_path: str, name: str): # Initialize a SparkSession spark = SparkSession.builder \ .appName( f"{name} benchmark derived from {benchmark}") \ .getOrCreate() spark.conf.set("spark.hadoop.fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider") spark.conf.set("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") # Register the tables num_queries = 22 table_names = [ "customer", "lineitem", "nation", "orders", "part", "partsupp", "region", "supplier", ] for table in table_names: path = f"{data_path}/{table}.parquet" print(f"Registering table {table} using path {path}") df = spark.read.parquet(path) df.createOrReplaceTempView(table) conf_dict = {k: v for k, v in spark.sparkContext.getConf().getAll()} results = { "engine": "spark", "benchmark": benchmark, "data_path": data_path, "query_path": query_path, "spark_conf": conf_dict, "queries": {}, } iter_start_time = time.time() for query in range(1, num_queries + 1): spark.sparkContext.setJobDescription(f"{benchmark} q{query}") # if query == 9: # continue # read text file path = f"{query_path}/q{query}.sql" # if query == 72: # # use version with sensible join order # path = f"{query_path}/q{query}_optimized.sql" print(f"Reading query {query} using path {path}") with open(path, "r") as f: text = f.read() # each file can contain multiple queries queries = list( filter(lambda x: len(x) > 0, map(lambda x: x.strip(), text.split(";"))) ) start_time = time.time() for sql in queries: sql = sql.strip().replace("create view", "create temp view") if len(sql) > 0: print(f"Executing: {sql}") df = spark.sql(sql) rows = df.collect() end_time = time.time() out_path = f"{output_path}/{name}_{benchmark}_q{query}_result.txt" # fIXME: concat output for all queries. For example q15 has multiple out = df._show_string(100000) with open(out_path, "w") as f: f.write(out) print(f"Query {query} took {end_time - start_time} seconds") results["queries"][str(query)] = end_time - start_time print(json.dumps(results, indent=4)) iter_end_time = time.time() print(f"total took {round(iter_end_time - iter_start_time,2)} seconds") out = json.dumps(results, indent=4) current_time_millis = int(datetime.now().timestamp() * 1000) results_path = f"{output_path}/{name}-{benchmark}-{current_time_millis}.json" print(f"Writing results to {results_path}") with open(results_path, "w") as f: f.write(out) # Stop the SparkSession spark.stop() if __name__ == "__main__": print(f"got arguments {sys.argv}") print(f"python version {sys.version}") print(f"python versioninfo {sys.version_info}") parser = argparse.ArgumentParser( description="DataFusion benchmark derived from TPC-H / TPC-DS" ) parser.add_argument( "--benchmark", required=True, help="Benchmark to run (tpch or tpcds)" ) parser.add_argument("--data", required=True, help="Path to data files") parser.add_argument("--queries", required=True, help="Path to query files") parser.add_argument("--output", required=True, help="Path to write output") parser.add_argument( "--name", required=True, help="Prefix for result file e.g. spark/comet/gluten" ) args = parser.parse_args() print(f"parsed is {args}") main(args.benchmark, args.data, args.queries, args.output, args.name)