def main()

in source/benchmark-sample/emr-benchmark.py [0:0]


def main():
    """Parse CLI options, run the selected file-system benchmarks, then plot results.

    Options:
      --engine {hive,spark}   engine forwarded to ``run_query`` and used in plot
                              names (default and bare-flag value: spark).
      --show-plot-only        do not prompt or run anything; only call ``mkplot``
                              for the selected file systems.
      --[no-]cleanup          forwarded to ``run_query`` as ``cleanup`` (default: off).
      --[no-]gendata          forwarded to ``run_query`` as ``gendata`` (default: on).
      --restore               forwarded to ``run_query`` as ``restore`` (default: off).
      --scale N               scale forwarded to ``run_query`` and embedded in the
                              result file names (default and bare-flag value: 2).
      --[no-]jfs / --[no-]s3 / --[no-]hdfs
                              select which file systems to benchmark.

    File systems are processed in the fixed order jfs, s3, hdfs; the same order
    is passed to ``mkplot``.  Prints ``nothing to do`` when none is selected.
    """
    import shlex  # local import: only needed to quote the S3 bucket for the shell

    class NegateAction(argparse.Action):
        """Store False for the ``--no-<name>`` spelling of a flag, True otherwise."""

        def __call__(self, parser, ns, values, option_string=None):
            setattr(ns, self.dest, not option_string.startswith('--no-'))

    parser = argparse.ArgumentParser(description='run emr benchmark all in one')
    # const= makes a bare "--engine" / "--scale" fall back to the default
    # instead of silently storing None.
    parser.add_argument('--engine', dest='engine', default='spark', const='spark', nargs='?',
                        choices=['hive', 'spark'])
    parser.add_argument('--show-plot-only', dest='show_plot_only', action='store_true', default=False,
                        help='will show plot only if set')
    parser.add_argument('--cleanup', '--no-cleanup', dest='cleanup', action=NegateAction, default=False, nargs=0,
                        help='whether to clean up benchmark existing data')
    parser.add_argument('--gendata', '--no-gendata', dest='gendata', action=NegateAction, default=True, nargs=0,
                        help='whether to generate benchmark data')
    parser.add_argument('--restore', dest='restore', action='store_true', default=False,
                        help='whether to restore the benchmark database from existing data')
    parser.add_argument('--scale', metavar='N', type=int, nargs='?', default=2, const=2,
                        help='benchmark scale factor (default: 2)')
    parser.add_argument('--s3', '--no-s3', dest='s3', action=NegateAction, default=False, nargs=0,
                        help='whether to enable s3 benchmark')
    parser.add_argument('--jfs', '--no-jfs', dest='jfs', action=NegateAction, default=False, nargs=0,
                        help='whether to enable jfs benchmark')
    parser.add_argument('--hdfs', '--no-hdfs', dest='hdfs', action=NegateAction, default=False, nargs=0,
                        help='whether to enable hdfs benchmark')

    args = parser.parse_args()

    # Taken before any benchmark runs; it tags the plot output file names.
    nowstr = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    scale = args.scale

    # This order fixes both the run order and the series order in the plots.
    fs_protocols = [proto for proto, enabled in (('jfs', args.jfs),
                                                 ('s3', args.s3),
                                                 ('hdfs', args.hdfs))
                    if enabled]
    if not fs_protocols:
        print('nothing to do')
        return

    if not args.show_plot_only:
        # Resolve every URI (prompting / creating the bucket as needed) before the
        # first run starts, so no question is asked between long-running benchmarks.
        fs_uris = {}
        if args.jfs:
            fs_uris['jfs'] = 'jfs://%s/' % (os.environ.get('JFS_VOL')
                                            or require_input('Enter your JuiceFS volume name for benchmark'))
        if args.s3:
            s3_bucket = require_input('Enter your S3 bucket name for benchmark. Will create it if it doesn\'t exist')
            s3_uri = 's3://%s/' % s3_bucket
            fs_uris['s3'] = s3_uri
            # The bucket name is user input spliced into a shell script: quote it.
            quoted_bucket = shlex.quote(s3_bucket)
            quoted_uri = shlex.quote(s3_uri)
            sh(f'''\
            if ! aws s3api head-bucket --bucket={quoted_bucket}; then
                echo {quoted_bucket}" not exist, create it firstly"
                aws s3 mb {quoted_uri}
            fi
            ''')
        if args.hdfs:
            # NOTE(review): "$(hostname)" is kept literally here; it only becomes the
            # real host name if run_query passes the URI through a shell — confirm.
            fs_uris['hdfs'] = 'hdfs://$(hostname)/'

        for proto in fs_protocols:
            run_query(fs_uris[proto], scale, proto,
                      cleanup=args.cleanup,
                      gendata=args.gendata,
                      restore=args.restore,
                      engine=args.engine)

    def query_order(path):
        # Take the query number from the file name only, so digits in QUERIES_DIR
        # cannot flatten the ordering. Ties (e.g. query14a/query14b) are broken by
        # name, and digit-less names sort last instead of raising IndexError.
        name = os.path.basename(path)
        number = re.search(r'\d+', name)
        return (0, int(number.group()), name) if number else (1, 0, name)

    queries = sorted(glob.glob(os.path.join(QUERIES_DIR, '*.sql')), key=query_order)

    def tpcds_setup_res_name(query, key, fs_proto):
        # One setup-timing result per file system; query/key are part of the
        # callback signature expected by mkplot but unused here.
        return f'tpcds-setup.sh.{fs_proto}.{scale}.res'

    def query_sql_res_name(query, key, fs_proto):
        return f'{query}.{key}.{fs_proto}.{scale}.res'

    mkplot(['tpcds-setup.sh'], 'tpcds-setup.sh duration', fs_protocols,
           filename=f'tpcds-setup-{scale}-duration.{nowstr}.res',
           get_name_func=tpcds_setup_res_name)
    mkplot(queries, f'{args.engine}.parquet', fs_protocols,
           filename=f'{args.engine}-parquet-{scale}-benchmark.{nowstr}.res',
           get_name_func=query_sql_res_name)
    mkplot(queries, f'{args.engine}.orc', fs_protocols,
           filename=f'{args.engine}-orc-{scale}-benchmark.{nowstr}.res',
           get_name_func=query_sql_res_name)