# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import pathlib
import re
import subprocess
import typing

import google.auth
import jinja2
from ruamel import yaml

# Directory layout, resolved relative to the location of this script.
CURRENT_PATH = pathlib.Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_PATH.parent
DATASETS_PATH = PROJECT_ROOT / "datasets"
AIRFLOW_TEMPLATES_PATH = PROJECT_ROOT / "templates" / "airflow"

# Jinja2 templates that are rendered and combined into a generated DAG file.
TEMPLATE_PATHS = {
    "dag": AIRFLOW_TEMPLATES_PATH / "dag.py.jinja2",
    "task": AIRFLOW_TEMPLATES_PATH / "task.py.jinja2",
    "license": AIRFLOW_TEMPLATES_PATH / "license_header.py.jinja2",
    "dag_context": AIRFLOW_TEMPLATES_PATH / "dag_context.py.jinja2",
    "default_args": AIRFLOW_TEMPLATES_PATH / "default_args.py.jinja2",
}

# Fallback used by `airflow_version()` when `dag.airflow_version` is absent.
DEFAULT_AIRFLOW_VERSION = 2
# Lookup table indexed as AIRFLOW_IMPORTS[version][operator]["import" | "class"].
# Loaded through `read_text()` so the file handle is closed immediately rather
# than being left open for the garbage collector to reclaim.
AIRFLOW_IMPORTS = json.loads((CURRENT_PATH / "dag_imports.json").read_text())
AIRFLOW_VERSIONS = list(AIRFLOW_IMPORTS.keys())


def main(
    dataset_id: str,
    pipeline_id: str,
    env: str,
    all_pipelines: bool = False,
    skip_builds: bool = False,
    async_builds: bool = False,
    format_code: bool = True,
):
    """Build the dataset's images (optionally) and generate its DAG file(s).

    Args:
        dataset_id: Directory name of the dataset under the datasets folder.
        pipeline_id: Directory name of the pipeline to generate. Ignored when
            `all_pipelines` is true.
        env: Environment name; generated files are copied under `.<env>`.
        all_pipelines: Generate a DAG for every pipeline subdirectory.
        skip_builds: Do not build container images.
        async_builds: Passed through to the image build step.
        format_code: Run the code formatters on each generated DAG file.

    Raises:
        ValueError: If no `pipeline_id` is given and `all_pipelines` is false.
    """
    # Fail fast, before any image build is started, instead of crashing later
    # with an opaque TypeError when the pipeline path is assembled.
    if not all_pipelines and not pipeline_id:
        raise ValueError(
            "`pipeline_id` is required unless `all_pipelines` is enabled"
        )

    if not skip_builds:
        build_images(dataset_id, env, async_builds)

    if all_pipelines:
        for pipeline_dir in list_subdirs(DATASETS_PATH / dataset_id / "pipelines"):
            generate_pipeline_dag(dataset_id, pipeline_dir.name, env, format_code)
    else:
        generate_pipeline_dag(dataset_id, pipeline_id, env, format_code)


def generate_pipeline_dag(
    dataset_id: str, pipeline_id: str, env: str, format_code: bool
):
    """Render one pipeline's `pipeline.yaml` into `<pipeline_id>_dag.py`.

    The generated file is written next to the YAML config, optionally
    formatted, then the pipeline directory is copied under `.<env>` and the
    Airflow variables referenced by the DAG are printed.
    """
    # Instantiating the class registers the `!IMAGE` tag for this dataset, so
    # it has to happen before the YAML is parsed.
    CustomYAMLTags(dataset_id)

    pipeline_path = DATASETS_PATH / dataset_id / "pipelines" / pipeline_id
    yaml_text = (pipeline_path / "pipeline.yaml").read_text()
    # NOTE(review): `yaml.Loader` is the full (non-safe) loader; the config is
    # presumably trusted repository content -- confirm.
    pipeline_config = yaml.load(yaml_text, Loader=yaml.Loader)

    validate_airflow_version_existence_and_value(pipeline_config)
    validate_dag_id_existence_and_format(pipeline_config)
    rendered_dag = generate_dag(pipeline_config, dataset_id)

    output_path = pipeline_path / f"{pipeline_id}_dag.py"
    output_path.touch()
    write_to_file(rendered_dag, output_path)

    if format_code:
        format_python_code(output_path)

    env_root = PROJECT_ROOT / f".{env}"
    copy_files_to_dot_dir(dataset_id, pipeline_id, env_root)

    print_airflow_variables(dataset_id, rendered_dag, env)


def generate_dag(config: dict, dataset_id: str) -> str:
    """Render the top-level DAG template from the pipeline config.

    Each section (imports, default args, DAG context, tasks) is produced by
    its own generator; `dag.graph_paths` is passed to the template as-is.
    """
    dag_template = jinja2.Template(TEMPLATE_PATHS["dag"].read_text())
    sections = {
        "package_imports": generate_package_imports(config),
        "default_args": generate_default_args(config),
        "dag_context": generate_dag_context(config, dataset_id),
        "tasks": generate_tasks(config),
        "graph_paths": config["dag"]["graph_paths"],
    }
    return dag_template.render(**sections)


def generate_package_imports(config: dict) -> str:
    """Return the newline-joined import statements required by the DAG.

    Always includes `from airflow import DAG`, plus the import statement of
    every operator used in `dag.tasks`, looked up in `AIRFLOW_IMPORTS` for
    the pipeline's Airflow version. Duplicates are dropped.

    Raises:
        KeyError: If the version or a task's operator is not in
            `AIRFLOW_IMPORTS`, or a task has no `operator` key.
    """
    _airflow_version = airflow_version(config)
    statements = {"from airflow import DAG"}
    for task in config["dag"]["tasks"]:
        statements.add(AIRFLOW_IMPORTS[_airflow_version][task["operator"]]["import"])
    # Set iteration order for strings varies between interpreter runs; sort so
    # the same config always yields the same file, even when formatting is off.
    return "\n".join(sorted(statements))


def generate_tasks(config: dict) -> list:
    """Return the rendered source of every task in `dag.tasks`, in order."""
    version = airflow_version(config)
    return [generate_task_contents(task, version) for task in config["dag"]["tasks"]]


def generate_default_args(config: dict) -> str:
    """Render the `default_args` template from the DAG's init section."""
    template = jinja2.Template(TEMPLATE_PATHS["default_args"].read_text())
    args_section = dag_init(config)["default_args"]
    return template.render(default_args=args_section)


def generate_dag_context(config: dict, dataset_id: str) -> str:
    """Render the DAG context template.

    The template receives the DAG's init parameters and the DAG ID prefixed
    with the dataset ID.
    """
    init_params = dag_init(config)
    context_template = jinja2.Template(TEMPLATE_PATHS["dag_context"].read_text())
    full_dag_id = namespaced_dag_id(init_params["dag_id"], dataset_id)
    return context_template.render(
        dag_init=init_params,
        namespaced_dag_id=full_dag_id,
    )


def generate_task_contents(task: dict, airflow_version: str) -> str:
    """Validate a single task config and render it with the task template.

    The task's own keys are passed to the template as keyword arguments,
    together with the operator class name looked up in `AIRFLOW_IMPORTS`.
    """
    validate_task(task, airflow_version)
    task_template = jinja2.Template(TEMPLATE_PATHS["task"].read_text())
    operator_class = AIRFLOW_IMPORTS[airflow_version][task["operator"]]["class"]
    return task_template.render(**task, namespaced_operator=operator_class)


def dag_init(config: dict) -> dict:
    """Return the DAG's init parameters.

    Reads `dag.initialize`, falling back to `dag.init` when the former is
    missing or empty. Returns None if neither key is present.
    """
    dag_section = config["dag"]
    primary = dag_section.get("initialize")
    if primary:
        return primary
    return dag_section.get("init")


def airflow_version(config: dict) -> str:
    """Return `dag.airflow_version` as a string, or the default if unset."""
    declared = config["dag"].get("airflow_version", DEFAULT_AIRFLOW_VERSION)
    return str(declared)


def namespaced_dag_id(dag_id: str, dataset_id: str) -> str:
    """Return the DAG ID prefixed with its dataset: `<dataset_id>.<dag_id>`."""
    return "{}.{}".format(dataset_id, dag_id)


def validate_airflow_version_existence_and_value(config: dict):
    """Check that `dag.airflow_version` is declared and supported.

    Raises:
        KeyError: If `dag.airflow_version` is not present.
        ValueError: If its string form is not one of `AIRFLOW_VERSIONS`.
    """
    dag_section = config["dag"]
    if "airflow_version" not in dag_section:
        raise KeyError("Missing required parameter:`dag.airflow_version`")

    declared = str(dag_section["airflow_version"])
    if declared in AIRFLOW_VERSIONS:
        return
    raise ValueError("`dag.airflow_version` must be a valid Airflow major version")


def validate_dag_id_existence_and_format(config: dict):
    """Check that the DAG's init section declares a well-formed `dag_id`.

    Raises:
        KeyError: If the init section or its `dag_id` is missing or empty.
        ValueError: If `dag_id` contains characters other than ASCII letters,
            digits, dots and underscores.
    """
    # `dag_init` returns None when neither `initialize` nor `init` is present;
    # treat that like a missing `dag_id` instead of raising AttributeError.
    init = dag_init(config) or {}
    if not init.get("dag_id"):
        raise KeyError("Missing required parameter:`dag_id`")

    # `fullmatch` rather than `match` with `^...$`: `$` also matches just
    # before a trailing newline, which would let "my_dag\n" through.
    dag_id_regex = r"[a-zA-Z0-9_\.]*"
    if not re.fullmatch(dag_id_regex, init["dag_id"]):
        raise ValueError(
            "`dag_id` must contain only alphanumeric, dot, and underscore characters"
        )


def validate_task(task: dict, airflow_version: str):
    """Check that a task config names a known operator and has a task ID.

    Args:
        task: One entry of `dag.tasks`.
        airflow_version: Key into `AIRFLOW_IMPORTS` selecting the operator table.

    Raises:
        KeyError: If `operator` is missing or empty, or if `args.task_id` is
            missing or empty (including when `args` itself is absent or null).
        ValueError: If the operator is not listed for the given Airflow version.
    """
    if not task.get("operator"):
        raise KeyError(f"`operator` key must exist in {task}")

    supported_operators = AIRFLOW_IMPORTS[airflow_version]
    if task["operator"] not in supported_operators:
        raise ValueError(
            f"`task.operator` must be one of {list(supported_operators.keys())}"
        )

    # A task with no `args` (or an empty `args:` that parses to None) should
    # produce the same descriptive error as one with no `task_id`.
    task_args = task.get("args") or {}
    if not task_args.get("task_id"):
        raise KeyError(f"`args.task_id` key must exist in {task}")


def list_subdirs(path: pathlib.Path) -> typing.List[pathlib.Path]:
    """Returns the subdirectories of `path`, sorted by path.

    Directories whose name starts with "." or "_" are skipped. The result is
    sorted because `iterdir()` yields entries in arbitrary order, which would
    otherwise make the processing order differ between runs and machines.
    """
    return sorted(
        entry
        for entry in path.iterdir()
        if entry.is_dir() and not entry.name.startswith((".", "_"))
    )


def write_to_file(contents: str, filepath: pathlib.Path):
    """Write `contents` to `filepath`, preceded by the license header.

    Any copy of the header already present in `contents` is removed first, so
    the header ends up in the file exactly once, at the top.
    """
    header = pathlib.Path(TEMPLATE_PATHS["license"]).read_text() + "\n"
    with open(filepath, "w") as output:
        body = contents.replace(header, "")
        output.write(header + body)


def format_python_code(target_file: pathlib.Path):
    """Format the generated file with `black`, then run `isort` on the project.

    `black` is invoked with an argument list rather than a shell string, so a
    path containing spaces or shell metacharacters is passed through intact.
    Its output is discarded and its exit status is ignored, as before. Unlike
    the shell form, a missing `black` executable now raises FileNotFoundError.

    Raises:
        subprocess.CalledProcessError: If `isort` exits with a non-zero status.
    """
    # DEVNULL instead of PIPE: nothing ever read the pipe, and waiting on a
    # child whose unread pipe fills up can block forever.
    subprocess.run(
        ["black", "-q", str(target_file)],
        stdout=subprocess.DEVNULL,
        check=False,
    )
    subprocess.check_call(["isort", "--profile", "black", "."], cwd=PROJECT_ROOT)


def print_airflow_variables(dataset_id: str, dag_contents: str, env: str):
    """Print the Airflow JSON variables referenced by the generated DAG.

    Scans `dag_contents` for `{{ var.json.<path> }}` template expressions and
    prints each distinct `<path>` once, shallowest first (fewest dots), with
    ties ordered alphabetically so the listing is the same on every run.

    Args:
        dataset_id: Dataset name, used in the printed variables file path.
        dag_contents: Source text of the generated DAG.
        env: Environment name, used in the printed variables file path.
    """
    # The dots are escaped so that only a literal `var.json.` prefix matches.
    var_regex = r"\{{2}\s*var\.json\.([a-zA-Z0-9_\.]*)?\s*\}{2}"
    print(
        "\nThe following Airflow variables must be set in"
        f"\n\n  .{env}/datasets/{dataset_id}/pipelines/{dataset_id}_variables.json"
        "\n\nusing JSON dot notation:"
        "\n"
    )
    unique_vars = set(re.findall(var_regex, dag_contents))
    # The secondary sort key (the name itself) removes the dependence on set
    # iteration order, which changes between interpreter runs.
    for var in sorted(unique_vars, key=lambda v: (v.count("."), v)):
        if var.startswith("json."):
            var = var.replace("json.", "", 1)
        elif var.startswith("value."):
            var = var.replace("value.", "", 1)
        print(f"  - {var}")
    print()


def copy_files_to_dot_dir(dataset_id: str, pipeline_id: str, env_dir: pathlib.Path):
    """Copy a pipeline directory into the environment's `pipelines` folder.

    The destination folder is created if needed; the copy is a recursive,
    forced `cp` run from the project root.
    """
    pipeline_source = PROJECT_ROOT / "datasets" / dataset_id / "pipelines" / pipeline_id
    pipelines_target = env_dir / "datasets" / dataset_id / "pipelines"
    pipelines_target.mkdir(parents=True, exist_ok=True)

    copy_command = ["cp", "-rf", str(pipeline_source), str(pipelines_target)]
    subprocess.check_call(copy_command, cwd=PROJECT_ROOT)


def build_images(dataset_id: str, env: str, async_builds: bool):
    """Build and push every image found under the dataset's `_images` folder.

    Does nothing when the dataset has no `pipelines/_images` directory. The
    image sources are first copied under `.<env>` and built from there.
    """
    images_root = DATASETS_PATH / dataset_id / "pipelines" / "_images"
    if not images_root.exists():
        return

    env_root = PROJECT_ROOT / f".{env}"
    copied_image_dirs = copy_image_files_to_dot_dir(dataset_id, images_root, env_root)
    for copied_dir in copied_image_dirs:
        build_and_push_image(dataset_id, copied_dir, async_builds)


def copy_image_files_to_dot_dir(
    dataset_id: str, parent_dir: pathlib.Path, env_dir: pathlib.Path
) -> typing.List[pathlib.Path]:
    """Copy `parent_dir` into the environment's `pipelines` folder.

    Returns the subdirectories of the copied `_images` folder, as reported by
    `list_subdirs`.
    """
    pipelines_target = env_dir / "datasets" / dataset_id / "pipelines"
    pipelines_target.mkdir(parents=True, exist_ok=True)

    copy_command = ["cp", "-rf", str(parent_dir), str(pipelines_target)]
    subprocess.check_call(copy_command, cwd=PROJECT_ROOT)

    copied_images_root = pipelines_target / "_images"
    return list_subdirs(copied_images_root)


def build_and_push_image(
    dataset_id: str, image_dir: pathlib.Path, async_builds: bool = False
):
    """Submit a `gcloud builds submit` for the image in `image_dir`.

    The image is tagged `gcr.io/<project>/<dataset_id>__<image_dir name>`.
    `--async` is added to the command only when `async_builds` is true.
    """
    image_name = f"{dataset_id}__{image_dir.name}"
    image_tag = f"gcr.io/{gcp_project_id()}/{image_name}"

    # gcloud builds submit [--async] --tag gcr.io/PROJECT_ID/IMAGE_NAME
    command = ["gcloud", "builds", "submit"]
    if async_builds:
        command.append("--async")
    command.extend(["--tag", image_tag])

    subprocess.check_call(command, cwd=image_dir)


def gcp_project_id() -> str:
    """Return the project ID reported by `google.auth.default()`."""
    # `default()` yields a pair; the project ID is its second element.
    auth_defaults = google.auth.default()
    return auth_defaults[1]


class CustomYAMLTags(yaml.YAMLObject):
    """Registers the custom `!IMAGE` YAML tag for a given dataset.

    Creating an instance has the side effect of adding a constructor for
    `!IMAGE` to the YAML library, bound to that instance's dataset.
    """

    def __init__(self, dataset):
        # Dataset name used as the prefix of expanded image names.
        self.dataset = dataset
        yaml.add_constructor("!IMAGE", self.image_constructor)

    def image_constructor(self, loader, node):
        """Expand `!IMAGE <name>` into a full image reference string.

        The project part is left as an Airflow template expression
        (`{{ var.value.gcp_project }}`) in the returned string.
        """
        image_name = loader.construct_scalar(node)
        return f"gcr.io/{{{{ var.value.gcp_project }}}}/{self.dataset}__{image_name}"


# Command-line entry point: parse the arguments and hand them to `main`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate Airflow DAG files for a dataset's pipelines"
    )
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        type=str,
        dest="dataset",
        help="The directory name of the dataset.",
    )
    parser.add_argument(
        "-p",
        "--pipeline",
        type=str,
        dest="pipeline",
        help="The directory name of the pipeline",
    )
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        default="dev",
        dest="env",
        help="The stage used for the resources: dev|staging|prod",
    )
    parser.add_argument(
        "--all-pipelines", required=False, dest="all_pipelines", action="store_true"
    )
    parser.add_argument(
        "--skip-builds", required=False, dest="skip_builds", action="store_true"
    )
    # `store_true`, not `store_false`: with `store_false` the option defaulted
    # to True and passing `--async-builds` turned async builds *off*, the
    # opposite of what the flag name and `main`'s default (False) indicate.
    parser.add_argument(
        "--async-builds", required=False, dest="async_builds", action="store_true"
    )

    args = parser.parse_args()

    # Without either option there is no pipeline to generate; report it as a
    # usage error instead of failing later while building the pipeline path.
    if not args.all_pipelines and not args.pipeline:
        parser.error("one of --pipeline or --all-pipelines is required")

    main(
        args.dataset,
        args.pipeline,
        args.env,
        args.all_pipelines,
        args.skip_builds,
        args.async_builds,
    )
