"""Query schema."""

import json
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, ClassVar, Dict, Iterable, List, Optional

import attr
import yaml
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.cloud.bigquery import SchemaField

from .. import dryrun

SCHEMA_FILE = "schema.yaml"


@attr.s(auto_attribs=True)
class Schema:
    """Query schema representation and helpers."""

    schema: Dict[str, Any]
    # type aliases normalized during traversal; ClassVar keeps attrs from
    # turning this constant into an instance attribute with a mutable default
    _type_mapping: ClassVar[Dict[str, str]] = {
        "INT64": "INTEGER",
        "BOOL": "BOOLEAN",
        "FLOAT64": "FLOAT",
    }

    @classmethod
    def from_query_file(cls, query_file: Path, *args, **kwargs):
        """Create schema from a query file."""
        if not query_file.is_file() or query_file.suffix != ".sql":
            raise Exception(f"{query_file} is not a valid SQL file.")

        schema = dryrun.DryRun(str(query_file), *args, **kwargs).get_schema()
        return cls(schema)
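
    # Usage sketch (the path is hypothetical): the schema is derived by
    # dry-running the query against BigQuery.
    #
    #   schema = Schema.from_query_file(
    #       Path("sql/project/dataset/table_v1/query.sql")
    #   )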

    @classmethod
    def from_schema_file(cls, schema_file: Path):
        """Create schema from a yaml schema file."""
        if not schema_file.is_file() or schema_file.suffix != ".yaml":
            raise Exception(f"{schema_file} is not a valid YAML schema file.")

        with open(schema_file) as file:
            schema = yaml.safe_load(file)
            return cls(schema)
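
    # The YAML file is expected to contain a mapping with a top-level
    # `fields` list, e.g. (illustrative):
    #
    #   fields:
    #   - name: submission_date
    #     type: DATE
    #     mode: NULLABLE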

    @classmethod
    def empty(cls):
        """Create an empty schema."""
        return cls({"fields": []})

    @classmethod
    def from_json(cls, json_schema):
        """Create schema from JSON object."""
        return cls(json_schema)
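
    # Usage sketch showing the expected shape (illustrative field):
    #
    #   schema = Schema.from_json(
    #       {"fields": [{"name": "client_id", "type": "STRING", "mode": "NULLABLE"}]}
    #   )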

    @classmethod
    def for_table(cls, project, dataset, table, partitioned_by=None, *args, **kwargs):
        """Get the schema for a BigQuery table."""
        query = f"SELECT * FROM `{project}.{dataset}.{table}`"

        if partitioned_by:
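            # the query is only dry-run, so any literal date satisfies the filter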
            query += f" WHERE DATE(`{partitioned_by}`) = DATE('2020-01-01')"

        try:
            return cls(
                dryrun.DryRun(
                    os.path.join(project, dataset, table, "query.sql"),
                    query,
                    *args,
                    project=project,
                    dataset=dataset,
                    table=table,
                    **kwargs,
                ).get_schema()
            )
        except Exception as e:
            print(f"Cannot get schema for {project}.{dataset}.{table}: {e}")
            return cls({"fields": []})
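
    # Usage sketch (project, dataset, and table are hypothetical):
    #
    #   schema = Schema.for_table(
    #       "my-project", "my_dataset", "my_table_v1",
    #       partitioned_by="submission_date",
    #   )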

    def deploy(self, destination_table: str) -> bigquery.Table:
        """Deploy the schema to BigQuery named after destination_table."""
        client = bigquery.Client()
        # write the schema to a temporary JSON file the client can parse;
        # the file is removed once the schema has been loaded
        with NamedTemporaryFile() as tmp_schema_file:
            self.to_json_file(Path(tmp_schema_file.name))
            bigquery_schema = client.schema_from_json(tmp_schema_file.name)

        try:
            # destination table already exists, update schema
            table = client.get_table(destination_table)
            table.schema = bigquery_schema
            return client.update_table(table, ["schema"])
        except NotFound:
            table = bigquery.Table(destination_table, schema=bigquery_schema)
            return client.create_table(table)
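
    # Usage sketch (the table path is hypothetical); the table is created if
    # it does not exist, otherwise only its schema is updated:
    #
    #   Schema.from_schema_file(Path("schema.yaml")).deploy(
    #       "my-project.my_dataset.my_table_v1"
    #   )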

    def merge(
        self,
        other: "Schema",
        exclude: Optional[List[str]] = None,
        add_missing_fields: bool = True,
        attributes: Optional[List[str]] = None,
        ignore_incompatible_fields: bool = False,
        ignore_missing_fields: bool = False,
    ):
        """Merge another schema into the schema."""
        if "fields" in other.schema and "fields" in self.schema:
            self._traverse(
                "root",
                self.schema["fields"],
                other.schema["fields"],
                update=True,
                exclude=exclude,
                add_missing_fields=add_missing_fields,
                attributes=attributes,
                ignore_incompatible_fields=ignore_incompatible_fields,
                ignore_missing_fields=ignore_missing_fields,
            )
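
    # Usage sketch (paths and names are hypothetical): pull fields that exist
    # in the deployed table but are missing locally, then persist the result:
    #
    #   local = Schema.from_schema_file(Path("schema.yaml"))
    #   deployed = Schema.for_table("my-project", "my_dataset", "my_table_v1")
    #   local.merge(deployed, exclude=["submission_date"])
    #   local.to_yaml_file(Path("schema.yaml"))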

    def equal(self, other: "Schema") -> bool:
        """Compare to another schema."""
        try:
            self._traverse(
                "root", self.schema["fields"], other.schema["fields"], update=False
            )
            self._traverse(
                "root", other.schema["fields"], self.schema["fields"], update=False
            )
        except Exception as e:
            print(e)
            return False

        return True

    def compatible(self, other: "Schema") -> bool:
        """
        Check whether this schema is compatible with another schema.

        A field that is missing from this schema but present in the "other"
        schema does not break compatibility. Fields missing from the "other"
        schema do break it: inserting data that follows this schema into the
        "other" schema would fail.
        """
        try:
            self._traverse(
                "root",
                self.schema["fields"],
                other.schema["fields"],
                update=False,
                ignore_missing_fields=True,
            )
            self._traverse(
                "root",
                other.schema["fields"],
                self.schema["fields"],
                update=False,
                ignore_missing_fields=False,
            )
        except Exception as e:
            print(e)
            return False

        return True
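
    # Illustration of the asymmetry (hypothetical fields): given
    #
    #   a = Schema({"fields": [{"name": "x", "type": "INT64"}]})
    #   b = Schema({"fields": [{"name": "x", "type": "INT64"},
    #                          {"name": "y", "type": "STRING"}]})
    #
    # a.compatible(b) is True since b has every field of a, plus extras,
    # while b.compatible(a) is False since a is missing y.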

    @staticmethod
    def _node_with_mode(node):
        """Add default value for mode to node."""
        if "mode" in node:
            return node
        return {"mode": "NULLABLE", **node}

    def _traverse(
        self,
        prefix,
        columns,
        other_columns,
        update=False,
        add_missing_fields=True,
        ignore_missing_fields=False,
        exclude=None,
        attributes=None,
        ignore_incompatible_fields=False,
    ):
        """Traverses two schemas for validation and optionally updates the first schema."""
        nodes = {n["name"]: Schema._node_with_mode(n) for n in columns}
        other_nodes = {
            n["name"]: Schema._node_with_mode(n)
            for n in other_columns
            if exclude is None or n["name"] not in exclude
        }

        for node_name, node in other_nodes.items():
            field_path = node["name"] + (".[]" if node["mode"] == "REPEATED" else "")
            dtype = node["type"]

            if node_name in nodes:
                # node exists in schema, update attributes where necessary
                for node_attr_key, node_attr_value in node.items():
                    if attributes and node_attr_key not in attributes:
                        continue

                    if node_attr_key == "type":
                        # sometimes types have multiple names (e.g. INT64 and INTEGER)
                        # make it consistent here
                        node_attr_value = self._type_mapping.get(
                            node_attr_value, node_attr_value
                        )
                        nodes[node_name][node_attr_key] = self._type_mapping.get(
                            nodes[node_name][node_attr_key],
                            nodes[node_name][node_attr_key],
                        )

                    if node_attr_key not in nodes[node_name]:
                        if update:
                            # add field attributes if not exists in schema
                            nodes[node_name][node_attr_key] = node_attr_value
                            # Netlify has a problem starting 2022-03-07 where lots of
                            # logging slows down builds to the point where our builds hit
                            # the time limit and fail (bug 1761292), and this print
                            # statement accounts for 84% of our build logging.
                            # TODO: Uncomment this print when Netlify fixes the problem.
                            # print(
                            #    f"Attribute {node_attr_key} added to {prefix}.{field_path}"
                            # )
                        else:
                            if node_attr_key == "description":
                                print(
                                    "Warning: descriptions for "
                                    f"{prefix}.{field_path} differ"
                                )
                            else:
                                if not ignore_incompatible_fields:
                                    raise Exception(
                                        f"{node_attr_key} missing in {prefix}.{field_path}"
                                    )
                    elif nodes[node_name][node_attr_key] != node_attr_value:
                        # check field attribute diffs
                        if node_attr_key == "description":
                            # differing descriptions are tolerated; only warn
                            print(
                                f"Warning: descriptions for {prefix}.{field_path} differ."
                            )
                        elif node_attr_key != "fields":
                            if not ignore_incompatible_fields:
                                raise Exception(
                                    f"Cannot merge schemas. {node_attr_key} attributes "
                                    f"for {prefix}.{field_path} are incompatible"
                                )

                if dtype == "RECORD" and nodes[node_name]["type"] == "RECORD":
                    # keep traversing nested fields
                    self._traverse(
                        f"{prefix}.{field_path}",
                        nodes[node_name]["fields"],
                        node["fields"],
                        update=update,
                        add_missing_fields=add_missing_fields,
                        ignore_missing_fields=ignore_missing_fields,
                        attributes=attributes,
                        ignore_incompatible_fields=ignore_incompatible_fields,
                    )
            else:
                if update and add_missing_fields:
                    # node does not exist in schema, add to schema
                    columns.append(node.copy())
                    print(f"Field {node_name} added to {prefix}")
                else:
                    if not ignore_missing_fields:
                        raise Exception(
                            f"Field {prefix}.{field_path} is missing in schema"
                        )

    def to_yaml_file(self, yaml_path: Path):
        """Write schema to the YAML file path."""
        with open(yaml_path, "w") as out:
            yaml.dump(self.schema, out, default_flow_style=False, sort_keys=False)

    def to_json_file(self, json_path: Path):
        """Write schema to the JSON file path."""
        with open(json_path, "w") as out:
            json.dump(self.schema["fields"], out, indent=2)

    def to_json(self):
        """Return the schema data as JSON."""
        return json.dumps(self.schema)
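
    # Note the asymmetry between the writers: to_yaml_file and to_json dump
    # the full {"fields": [...]} object, while to_json_file writes only the
    # bare fields list, which is the layout client.schema_from_json expects
    # (see deploy above). Illustrative paths:
    #
    #   schema.to_yaml_file(Path("schema.yaml"))
    #   schema.to_json_file(Path("schema.json"))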

    def to_bigquery_schema(self) -> List[SchemaField]:
        """Get the BigQuery representation of the schema."""
        return [SchemaField.from_api_repr(field) for field in self.schema["fields"]]

    @classmethod
    def from_bigquery_schema(cls, fields: List[SchemaField]) -> "Schema":
        """Construct a Schema from the BigQuery representation."""
        return cls({"fields": [field.to_api_repr() for field in fields]})
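
    # Round-trip sketch between the two representations; to_api_repr may
    # normalize attributes such as mode and description, which equal() only
    # warns about:
    #
    #   fields = schema.to_bigquery_schema()  # -> List[SchemaField]
    #   assert Schema.from_bigquery_schema(fields).equal(schema)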

    def generate_compatible_select_expression(
        self,
        target_schema: "Schema",
        fields_to_remove: Optional[Iterable[str]] = None,
        unnest_structs: bool = False,
        max_unnest_depth: int = 0,
        unnest_allowlist: Optional[Iterable[str]] = None,
    ) -> str:
        """Generate the select expression for the source schema based on the target schema.

        The output will include all fields of the target schema in the same order of the target.
        Any fields that are missing in the source schema are set to NULL.

        :param target_schema: The schema to coerce the current schema to.
        :param fields_to_remove: Given fields are removed from the output expression. Expressed as a
            list of strings with `.` separating each level of nesting, e.g. record_name.field.
        :param unnest_structs: If true, all record fields are expressed as structs with all nested
            fields explicitly listed. This allows the expression to be compatible even if the
            source schemas get new fields added. Otherwise, records are only unnested if they
            do not match the target schema.
        :param max_unnest_depth: Maximum level of struct nesting to explicitly unnest in
            the expression.
        :param unnest_allowlist: If set, only the given top-level structs are unnested.
        """

        def _type_info(node):
            """Determine the BigQuery type information from Schema object field."""
            dtype = node["type"]
            if dtype == "RECORD":
                dtype = (
                    "STRUCT<"
                    + ", ".join(
                        f"`{field['name']}` {_type_info(field)}"
                        for field in node["fields"]
                    )
                    + ">"
                )
            elif dtype == "FLOAT":
                dtype = "FLOAT64"
            if node.get("mode") == "REPEATED":
                return f"ARRAY<{dtype}>"
            return dtype

        def recurse_fields(
            _source_schema_nodes: List[Dict],
            _target_schema_nodes: List[Dict],
            path=None,
        ) -> str:
            if path is None:
                path = []

            select_expr = []
            source_schema_nodes = {n["name"]: n for n in _source_schema_nodes}
            target_schema_nodes = {n["name"]: n for n in _target_schema_nodes}

            # iterate through fields
            for node_name, node in target_schema_nodes.items():
                dtype = node["type"]
                node_path = path + [node_name]
                node_path_str = ".".join(node_path)

                if node_name in source_schema_nodes:  # field exists in source schema
                    # field matches, can query as-is
                    if node == source_schema_nodes[node_name] and (
                        # don't need to unnest scalar
                        dtype != "RECORD"
                        or not unnest_structs
                        # reached max record depth to unnest
                        or len(node_path) > max_unnest_depth > 0
                        # field not in unnest allowlist
                        or (
                            unnest_allowlist is not None
                            and node_path[0] not in unnest_allowlist
                        )
                    ):
                        if (
                            fields_to_remove is None
                            or node_path_str not in fields_to_remove
                        ):
                            select_expr.append(node_path_str)
                    elif (
                        dtype == "RECORD"
                    ):  # for nested fields, recursively generate select expression
                        if (
                            node.get("mode", None) == "REPEATED"
                        ):  # unnest repeated record
                            select_expr.append(
                                f"""
                                    ARRAY(
                                        SELECT
                                            STRUCT(
                                                {recurse_fields(
                                                    source_schema_nodes[node_name]['fields'],
                                                    node['fields'],
                                                    [node_name],
                                                )}
                                            )
                                        FROM UNNEST({node_path_str}) AS `{node_name}`
                                    ) AS `{node_name}`
                                """
                            )
                        else:  # select struct fields
                            select_expr.append(
                                f"""
                                    STRUCT(
                                        {recurse_fields(
                                            source_schema_nodes[node_name]['fields'],
                                            node['fields'],
                                            node_path,
                                        )}
                                    ) AS `{node_name}`
                                """
                            )
                    else:  # scalar value doesn't match, e.g. different types
                        select_expr.append(
                            f"CAST(NULL AS {_type_info(node)}) AS `{node_name}`"
                        )
                else:  # field not found in source schema
                    select_expr.append(
                        f"CAST(NULL AS {_type_info(node)}) AS `{node_name}`"
                    )

            return ", ".join(select_expr)

        return recurse_fields(
            self.schema["fields"],
            target_schema.schema["fields"],
        )
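
    # Illustrative result (hypothetical fields): a target field missing from
    # the source is NULL-filled with the target's type so the expression
    # still matches the target column list:
    #
    #   source = Schema({"fields": [
    #       {"name": "a", "type": "INT64", "mode": "NULLABLE"}]})
    #   target = Schema({"fields": [
    #       {"name": "a", "type": "INT64", "mode": "NULLABLE"},
    #       {"name": "b", "type": "STRING", "mode": "NULLABLE"}]})
    #   source.generate_compatible_select_expression(target)
    #   # -> "a, CAST(NULL AS STRING) AS `b`"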

    def generate_select_expression(
        self,
        remove_fields: Optional[Iterable[str]] = None,
        unnest_structs: bool = False,
        max_unnest_depth: int = 0,
        unnest_allowlist: Optional[Iterable[str]] = None,
    ) -> str:
        """Generate the select expression for the schema which includes each field."""
        return self.generate_compatible_select_expression(
            self,
            remove_fields,
            unnest_structs,
            max_unnest_depth,
            unnest_allowlist,
        )
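
    # Usage sketch: selecting a schema against itself keeps every field;
    # unnest_structs=True spells records out as explicit STRUCT(...) column
    # lists so fields later added to the source table do not leak through:
    #
    #   expr = schema.generate_select_expression(unnest_structs=True)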
