def main()

in perfkitbenchmarker/scripts/spark_sql_test_scripts/spark_sql_runner.py
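The body below relies on module-level imports defined elsewhere in spark_sql_runner.py; judging from how the names are used in this function, they are roughly:

from concurrent import futures
import json
import logging

from pyspark import sql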


def main(args, results_logger_getter=get_results_logger):
  builder = sql.SparkSession.builder.appName('Spark SQL Query')
  if args.enable_hive:
    builder = builder.enableHiveSupport()
  query_streams = args.sql_queries
  if len(query_streams) > 1:
    # FAIR scheduling shares Spark resources across the concurrent query
    # streams so no single stream monopolizes the cluster.
    builder = builder.config('spark.scheduler.mode', 'FAIR')
  spark = builder.getOrCreate()
  if args.database:
    spark.catalog.setCurrentDatabase(args.database)
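  # Register externally stored tables (format and options from the table
  # metadata) as temp views so queries can reference them by name.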
  for name, (fmt, options) in get_table_metadata(args).items():
    logging.info('Loading %s', name)
    spark.read.format(fmt).options(**options).load().createTempView(name)
  if args.table_cache:
    # listTables() returns both the tables in args.database and the temp views
    # registered above from table_metadata, so all of them get cached.
    for table in spark.catalog.listTables():
      spark.sql(
          'CACHE {lazy} TABLE {name}'.format(
              lazy='LAZY' if args.table_cache == 'lazy' else '', name=table.name
          )
      )
  if args.dump_spark_conf:
    logging.info(
        'Dumping the spark conf properties to %s', args.dump_spark_conf
    )
    props = [
        sql.Row(key=key, val=val)
        for key, val in spark.sparkContext.getConf().getAll()
    ]
    spark.createDataFrame(props).coalesce(1).write.mode('append').json(
        args.dump_spark_conf
    )

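  # Run each query stream in its own thread; with multiple streams, the FAIR
  # scheduler configured above shares Spark resources across them.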
  threads = len(query_streams)
  executor = futures.ThreadPoolExecutor(max_workers=threads)
  result_futures = [
      executor.submit(
          run_sql_query, spark, stream, i, args.fail_on_query_execution_errors
      )
      for i, stream in enumerate(query_streams)
  ]
  futures.wait(result_futures)
  results = []
  for f in result_futures:
    results += f.result()

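  # Results either go to the driver log between sentinel markers (so a caller
  # can extract them from the log) or to a JSON report under report_dir.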
  if args.log_results:
    dumped_results = '\n'.join([
        '----@spark_sql_runner:results_start@----',
        json.dumps(results),
        '----@spark_sql_runner:results_end@----',
    ])
    results_logger = results_logger_getter(spark.sparkContext)
    results_logger.info(dumped_results)
  else:
    logging.info('Writing results to %s', args.report_dir)
    results_as_rows = [
        sql.Row(
            stream=r['stream'], query_id=r['query_id'], duration=r['duration']
        )
        for r in results
    ]
    spark.createDataFrame(results_as_rows).coalesce(1).write.mode(
        'append'
    ).json(args.report_dir)
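For illustration, a minimal sketch of the namespace main() expects, assuming it is constructed directly rather than by the script's real argument parser. The attribute names come from the reads above; the value shapes (each stream being a list of query strings), the example queries, and the /tmp/spark_sql_report path are assumptions.

import argparse

# Hypothetical args for a local run; the real script builds this namespace via
# its own argparse flags defined elsewhere in spark_sql_runner.py.
args = argparse.Namespace(
    sql_queries=[['q1', 'q2'], ['q3']],  # two concurrent streams (shape assumed)
    enable_hive=False,
    database=None,                       # skip setCurrentDatabase
    table_cache='lazy',                  # falsy disables caching; 'lazy' issues CACHE LAZY TABLE
    dump_spark_conf=None,                # path to dump Spark conf as JSON, or None
    fail_on_query_execution_errors=True,
    log_results=True,                    # log results between sentinel markers
    report_dir='/tmp/spark_sql_report',  # only used when log_results is False
)
# main(args) would then build the SparkSession and run both streams concurrently.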