#!/usr/bin/python3

"""Read a table from BigQuery and write it as parquet."""

import json
import re
import sys
from argparse import ArgumentParser
from textwrap import dedent

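# google-cloud-bigquery is optional: defer any ImportError so --dry-run can work
# without it (see transform_schema and the non-dry-run branch of main)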
try:
    from google.cloud import bigquery
except ImportError as e:
    bigquery = None  # type: ignore
    bigquery_error = e

parser = ArgumentParser(description=__doc__)
parser.add_argument("table", help="BigQuery table to read")
parser.add_argument(
    "--avro-path",
    help="Avro export path that should be read instead of BigQuery Storage API.",
)
parser.add_argument(
    "--dataset",
    dest="dataset",
    default="telemetry",
    help='BigQuery dataset that contains TABLE. Defaults to "telemetry".',
)
parser.add_argument(
    "--destination",
    default="s3://telemetry-parquet/",
    help="The path where parquet will be written. This will have table and"
    " --static-partitions appended as directories. If this starts with s3:// it will"
    " be replaced with s3a://.",
)
parser.add_argument(
    "--destination-table",
    default=None,
    help="The name of the destination parquet table. Defaults to the source table name."
    " If this ends with a version by matching /_v[0-9]+$/ it will be converted to"
    " a directory in the output.",
)
parser.add_argument(
    "--drop",
    default=[],
    dest="drop",
    nargs="+",
    help="A list of fields to exclude from the output. Passed to df.drop(*_) after"
    " --select is applied.",
)
parser.add_argument(
    "--dry-run",
    dest="dry_run",
    action="store_true",
    help="Print the spark job that would be run, without running it, and exit.",
)
parser.add_argument(
    "--filter",
    default=[],
    dest="filter",
    nargs="+",
    help="A list of conditions to limit the input before it reaches spark. Passed to"
    ' spark.read.format("bigquery").option("filter", " AND ".join(_)). If a field is'
    " referenced here and later dropped via --drop or --static-partitions then it must"
    " be referenced in --where.",
)
parser.add_argument(
    "--partition-by",
    default=["submission_date"],
    dest="partition_by",
    nargs="+",
    help="A list of fields on which to dynamically partition the output. Passed to"
    ' df.write.partitionBy(*_). Defaults to ["submission_date"]. Fields specified in'
    " --static-partitions will be ignored.",
)
parser.add_argument(
    "--select",
    default=["*"],
    dest="select",
    nargs="+",
    help="A list of all fields to include in the output."
    " Passed to df.selectExpr(*_) after --replace is applied and before --drop."
    " Can include sql expressions.",
)
parser.add_argument(
    "--maps-from-entries",
    action="store_true",
    help="Recursively convert repeated key-value structs with maps.",
)
parser.add_argument(
    "--bigint-columns",
    nargs="*",
    help="A list of columns that should remain BIGINT in the output, while any other "
    " BIGINT columns are converted to INT. If unspecified, all columns remain BIGINT.",
)
parser.add_argument(
    "--replace",
    default=[],
    nargs="+",
    help="A list of expressions that modify columns in the output. Passed to"
    " df.withColumnExpr(_) one at a time after --where applied and before --select."
    " Can include sql expressions.",
)
parser.add_argument(
    "--static-partitions",
    dest="static_partitions",
    default=[],
    nargs="+",
    help="Static partitions specified as FIELD=VALUE that will be appended to"
    " --destination after table, with FIELD dropped from the input.",
)
parser.add_argument(
    "--submission-date",
    dest="submission_date",
    help="Short for: --filter \"submission_date = DATE 'SUBMISSION_DATE'\""
    " --where \"submission_date = DATE 'SUBMISSION_DATE'\""
    " --static-partitions submission_date=SUBMISSION_DATE",
)
parser.add_argument(
    "--where",
    dest="where",
    default="TRUE",
    help="An expression to limit the output. Passed to df.where(_). If a field is"
    " referenced in filter and later dropped via --drop or --static-partitions then it"
    " must be referenced here.",
)
parser.add_argument(
    "--write-mode",
    dest="write_mode",
    default="overwrite",
    help='Passed to df.write.mode(_). Defaults to "overwrite".',
)
parser.add_argument(
    "--partition-overwrite-mode",
    default="STATIC",
    type=str.upper,
    help='Passed to spark.conf.set("spark.sql.sources.partitionOverwriteMode", _).'
    ' Defaults to "STATIC".',
)
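
# Example invocation (a sketch: the table name below is hypothetical, and how this
# script is submitted and how the spark-bigquery connector is provided depend on
# your environment):
#
#   spark-submit export_to_parquet.py clients_daily_v6 \
#       --submission-date 2019-10-01 \
#       --maps-from-entries \
#       --dry-run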


def transform_field(
    field, maps_from_entries=False, bigint_columns=None, transform_layer=0, *prefix
):
    """
    Generate Spark SQL to recursively convert field types.

    If maps_from_entries is True, convert repeated key-value structs to maps.

    If bigint_columns is a list, convert non-matching BIGINT columns to INT.
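
    For example, with maps_from_entries=True a repeated RECORD field named
    "experiments" whose subfields are exactly "key" and "value" yields
    "MAP_FROM_ENTRIES(experiments) AS experiments", and with bigint_columns=[]
    a top-level INTEGER field "sample_id" yields "INT(sample_id) AS sample_id".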
    """
    transformed = False
    result = full_name = ".".join(prefix + (field.name,))
    repeated = field.mode == "REPEATED"
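    # repeated fields are rewritten with TRANSFORM(); use a distinct lambda variable
    # name per nesting level so nested lambdas don't shadow each other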
    if repeated:
        if transform_layer > 0:
            prefix = (f"_{transform_layer}",)
        else:
            prefix = ("_",)
        transform_layer += 1
    else:
        prefix = (*prefix, field.name)
    if field.field_type == "RECORD":
        if bigint_columns is not None:
            # get the bigint_columns nested under this field
            prefix_len = len(field.name) + 1
            bigint_columns = [
                column[prefix_len:]
                for column in bigint_columns
                if column.startswith(field.name + ".")
            ]
        subfields = [
            transform_field(
                subfield, maps_from_entries, bigint_columns, transform_layer, *prefix
            )
            for subfield in field.fields
        ]
        if any(subfield_transformed for _, subfield_transformed in subfields):
            transformed = True
            fields = ", ".join(transform for transform, _ in subfields)
            result = f"STRUCT({fields})"
            if repeated:
                result = f"TRANSFORM({full_name}, {prefix[0]} -> {result})"
        if maps_from_entries:
            if repeated and {"key", "value"} == {f.name for f in field.fields}:
                transformed = True
                result = f"MAP_FROM_ENTRIES({result})"
    elif field.field_type == "INTEGER":
        if bigint_columns is not None and field.name not in bigint_columns:
            transformed = True
            if repeated:
                result = f"TRANSFORM({full_name}, {prefix[0]} -> INT({prefix[0]}))"
            else:
                result = f"INT({full_name})"
    return f"{result} AS {field.name}", transformed


def transform_schema(table, maps_from_entries=False, bigint_columns=None):
    """Get maps_from_entries expressions for all maps in the given table."""
    if bigquery is None:
        return ["..."]
    schema = sorted(bigquery.Client().get_table(table).schema, key=lambda f: f.name)
    replace = []
    for index, field in enumerate(schema):
        try:
            expr, transformed = transform_field(
                field, maps_from_entries, bigint_columns
            )
        except Exception:
            json_field = json.dumps(field.to_api_repr(), indent=2)
            print(f"problem with field {index}:\n{json_field}", file=sys.stderr)
            raise
        if transformed:
            replace += [expr]
    return replace


def main():
    """Read a table from BigQuery and write it as parquet."""
    args = parser.parse_args()

    # handle --submission-date
    if args.submission_date is not None:
        # --filter "submission_date = DATE 'SUBMISSION_DATE'"
        condition = "submission_date = DATE '" + args.submission_date + "'"
        args.filter.append(condition)
        # --static-partitions submission_date=SUBMISSION_DATE
        args.static_partitions.append("submission_date=" + args.submission_date)
        # --where "submission_date IS NOT NULL"
        if args.where == "TRUE":
            args.where = condition
        else:
            args.where = "(" + args.where + ") AND " + condition

    # Set default --destination-table if it was not provided
    if args.destination_table is None:
        args.destination_table = args.table

    # append table and --static-partitions to destination
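    # e.g. with a (hypothetical) table main_summary_v4 and --submission-date 2019-10-01
    # the default destination becomes
    # s3a://telemetry-parquet/main_summary/v4/submission_date=2019-10-01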
    args.destination = "/".join(
        [
            re.sub("^s3://", "s3a://", args.destination).rstrip("/"),
            re.sub("_(v[0-9]+)$", r"/\1", args.destination_table.rsplit(".", 1).pop()),
        ]
        + args.static_partitions
    )

    # convert --static-partitions to a dict
    args.static_partitions = dict(p.split("=", 1) for p in args.static_partitions)

    # remove --static-partitions fields from --partition-by
    args.partition_by = [
        f for f in args.partition_by if f not in args.static_partitions
    ]

    # add --static-partitions fields to --drop
    args.drop += args.static_partitions.keys()

    # convert --filter to a single string
    args.filter = " AND ".join(args.filter)

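    # derive --replace expressions from the BigQuery schema when schema-driven
    # transforms were requested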
    if args.maps_from_entries or args.bigint_columns is not None:
        if "." in args.table:
            table_ref = args.table.replace(":", ".")
        else:
            table_ref = f"{args.dataset}.{args.table}"
        args.replace += transform_schema(
            table_ref, args.maps_from_entries, args.bigint_columns
        )

    if args.dry_run:
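        # print --replace with one expression per line when its repr would be long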
        replace = f"{args.replace!r}"
        if len(replace) > 60:
            replace = (
                "["
                + ",".join(f"\n{' '*4*5}{expr!r}" for expr in args.replace)
                + f"\n{' '*4*4}]"
            )
        print("spark = SparkSession.builder.appName('export_to_parquet').getOrCreate()")
        print("")
        print(
            "spark.conf.set('spark.sql.sources.partitionOverwriteMode', "
            f"{args.partition_overwrite_mode!r})"
        )
        print("")
        if args.avro_path is not None:
            print(f"df = spark.read.format('avro').load({args.avro_path!r})")
        else:
            print(
                dedent(
                    f"""
                    df = (
                        spark.read.format('bigquery')
                        .option('dataset', {args.dataset!r})
                        .option('table', {args.table!r})
                        .option('filter', {args.filter!r})
                        .option("parallelism", 0)  # let BigQuery storage API decide
                        .load()
                    )
                    """
                ).strip()
            )
        print("")
        print(
            dedent(
                f"""
                df = df.where({args.where!r}).selectExpr(*{args.select!r}).drop(*{args.drop!r})

                for sql in {replace}:
                    value, name = re.fullmatch("(?i)(.*) AS (.*)", sql).groups()
                    df = df.withColumn(name, expr(value))

                (
                    df.write.mode({args.write_mode!r})
                    .partitionBy(*{args.partition_by!r})
                    .parquet({args.destination!r})
                )
                """  # noqa:E501
            ).strip()
        )
    else:
        # delay import to allow --dry-run without spark
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import expr

        if bigquery is None:
            raise bigquery_error

        spark = SparkSession.builder.appName("export_to_parquet").getOrCreate()

        spark.conf.set(
            "spark.sql.sources.partitionOverwriteMode", args.partition_overwrite_mode
        )

        # run spark job from parsed args
        if args.avro_path is not None:
            df = spark.read.format("avro").load(args.avro_path)
        else:
            df = (
                spark.read.format("bigquery")
                .option("dataset", args.dataset)
                .option("table", args.table)
                .option("filter", args.filter)
                .option("parallelism", 0)  # let BigQuery storage API decide
                .load()
            )

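        # apply --where, --select, and --drop, in that order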
        df = df.where(args.where).selectExpr(*args.select).drop(*args.drop)

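        # apply each --replace expression, which must have the form "EXPRESSION AS NAME"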
        for sql in args.replace:
            value, name = re.fullmatch("(?i)(.*) AS (.*)", sql).groups()
            df = df.withColumn(name, expr(value))

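        # write partitioned parquet to --destination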
        (
            df.write.mode(args.write_mode)
            .partitionBy(*args.partition_by)
            .parquet(args.destination)
        )


if __name__ == "__main__":
    main()
