scripts/generate_dag.py:

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import pathlib
import re
import subprocess
import typing

import google.auth
import jinja2
from ruamel import yaml

CURRENT_PATH = pathlib.Path(__file__).resolve().parent
PROJECT_ROOT = CURRENT_PATH.parent
DATASETS_PATH = PROJECT_ROOT / "datasets"
AIRFLOW_TEMPLATES_PATH = PROJECT_ROOT / "templates" / "airflow"

TEMPLATE_PATHS = {
    "dag": AIRFLOW_TEMPLATES_PATH / "dag.py.jinja2",
    "task": AIRFLOW_TEMPLATES_PATH / "task.py.jinja2",
    "license": AIRFLOW_TEMPLATES_PATH / "license_header.py.jinja2",
    "dag_context": AIRFLOW_TEMPLATES_PATH / "dag_context.py.jinja2",
    "default_args": AIRFLOW_TEMPLATES_PATH / "default_args.py.jinja2",
}

DEFAULT_AIRFLOW_VERSION = 2

# Operator import statements and class names, keyed by Airflow major version.
AIRFLOW_IMPORTS = json.load(open(CURRENT_PATH / "dag_imports.json"))
AIRFLOW_VERSIONS = list(AIRFLOW_IMPORTS.keys())


def main(
    dataset_id: str,
    pipeline_id: str,
    env: str,
    all_pipelines: bool = False,
    skip_builds: bool = False,
    async_builds: bool = False,
    format_code: bool = True,
):
    if not skip_builds:
        build_images(dataset_id, env, async_builds)

    if all_pipelines:
        for pipeline_dir in list_subdirs(DATASETS_PATH / dataset_id / "pipelines"):
            generate_pipeline_dag(dataset_id, pipeline_dir.name, env, format_code)
    else:
        generate_pipeline_dag(dataset_id, pipeline_id, env, format_code)


def generate_pipeline_dag(
    dataset_id: str, pipeline_id: str, env: str, format_code: bool
):
    CustomYAMLTags(dataset_id)
    pipeline_dir = DATASETS_PATH / dataset_id / "pipelines" / pipeline_id
    config = yaml.load((pipeline_dir / "pipeline.yaml").read_text(), Loader=yaml.Loader)

    validate_airflow_version_existence_and_value(config)
    validate_dag_id_existence_and_format(config)

    dag_contents = generate_dag(config, dataset_id)

    dag_path = pipeline_dir / f"{pipeline_id}_dag.py"
    dag_path.touch()
    write_to_file(dag_contents, dag_path)

    if format_code:
        format_python_code(dag_path)

    copy_files_to_dot_dir(
        dataset_id,
        pipeline_id,
        PROJECT_ROOT / f".{env}",
    )

    print_airflow_variables(dataset_id, dag_contents, env)


def generate_dag(config: dict, dataset_id: str) -> str:
    return jinja2.Template(TEMPLATE_PATHS["dag"].read_text()).render(
        package_imports=generate_package_imports(config),
        default_args=generate_default_args(config),
        dag_context=generate_dag_context(config, dataset_id),
        tasks=generate_tasks(config),
        graph_paths=config["dag"]["graph_paths"],
    )


def generate_package_imports(config: dict) -> str:
    _airflow_version = airflow_version(config)
    contents = {"from airflow import DAG"}
    for task in config["dag"]["tasks"]:
        contents.add(AIRFLOW_IMPORTS[_airflow_version][task["operator"]]["import"])
    return "\n".join(contents)


def generate_tasks(config: dict) -> list:
    _airflow_version = airflow_version(config)
    contents = []
    for task in config["dag"]["tasks"]:
        contents.append(generate_task_contents(task, _airflow_version))
    return contents


def generate_default_args(config: dict) -> str:
    return jinja2.Template(TEMPLATE_PATHS["default_args"].read_text()).render(
        default_args=dag_init(config)["default_args"]
    )


def generate_dag_context(config: dict, dataset_id: str) -> str:
    dag_params = dag_init(config)
    return jinja2.Template(TEMPLATE_PATHS["dag_context"].read_text()).render(
        dag_init=dag_params,
        namespaced_dag_id=namespaced_dag_id(dag_params["dag_id"], dataset_id),
    )


def generate_task_contents(task: dict, airflow_version: str) -> str:
    validate_task(task, airflow_version)
    return jinja2.Template(TEMPLATE_PATHS["task"].read_text()).render(
        **task,
        namespaced_operator=AIRFLOW_IMPORTS[airflow_version][task["operator"]]["class"],
    )


def dag_init(config: dict) -> dict:
    return config["dag"].get("initialize") or config["dag"].get("init")


def airflow_version(config: dict) -> str:
    return str(config["dag"].get("airflow_version", DEFAULT_AIRFLOW_VERSION))


def namespaced_dag_id(dag_id: str, dataset_id: str) -> str:
    return f"{dataset_id}.{dag_id}"


def validate_airflow_version_existence_and_value(config: dict):
    if "airflow_version" not in config["dag"]:
        raise KeyError("Missing required parameter: `dag.airflow_version`")

    if str(config["dag"]["airflow_version"]) not in AIRFLOW_VERSIONS:
        raise ValueError("`dag.airflow_version` must be a valid Airflow major version")


def validate_dag_id_existence_and_format(config: dict):
    init = dag_init(config)
    if not init.get("dag_id"):
        raise KeyError("Missing required parameter: `dag_id`")

    dag_id_regex = r"^[a-zA-Z0-9_\.]*$"
    if not re.match(dag_id_regex, init["dag_id"]):
        raise ValueError(
            "`dag_id` must contain only alphanumeric, dot, and underscore characters"
        )


def validate_task(task: dict, airflow_version: str):
    if not task.get("operator"):
        raise KeyError(f"`operator` key must exist in {task}")

    if task["operator"] not in AIRFLOW_IMPORTS[airflow_version]:
        raise ValueError(
            f"`task.operator` must be one of {list(AIRFLOW_IMPORTS[airflow_version].keys())}"
        )

    if not task["args"].get("task_id"):
        raise KeyError(f"`args.task_id` key must exist in {task}")


def list_subdirs(path: pathlib.Path) -> typing.List[pathlib.Path]:
    """Returns a list of subdirectories, ignoring hidden and underscore-prefixed dirs."""
    subdirs = [
        f for f in path.iterdir() if f.is_dir() and f.name[0] not in (".", "_")
    ]
    return subdirs


def write_to_file(contents: str, filepath: pathlib.Path):
    # Prepend the license header, avoiding duplication if it is already present.
    license_header = pathlib.Path(TEMPLATE_PATHS["license"]).read_text() + "\n"
    with open(filepath, "w") as file_:
        file_.write(license_header + contents.replace(license_header, ""))


def format_python_code(target_file: pathlib.Path):
    subprocess.Popen(
        f"black -q {target_file}", stdout=subprocess.PIPE, shell=True
    ).wait()
    subprocess.check_call(["isort", "--profile", "black", "."], cwd=PROJECT_ROOT)


def print_airflow_variables(dataset_id: str, dag_contents: str, env: str):
    var_regex = r"\{{2}\s*var.json.([a-zA-Z0-9_\.]*)?\s*\}{2}"
    print(
        "\nThe following Airflow variables must be set in"
        f"\n\n  .{env}/datasets/{dataset_id}/pipelines/{dataset_id}_variables.json"
        "\n\nusing JSON dot notation:"
        "\n"
    )
    for var in sorted(
        list(set(re.findall(var_regex, dag_contents))), key=lambda v: v.count(".")
    ):
        if var.startswith("json."):
            var = var.replace("json.", "", 1)
        elif var.startswith("value."):
            var = var.replace("value.", "", 1)
        print(f"  - {var}")
    print()


def copy_files_to_dot_dir(dataset_id: str, pipeline_id: str, env_dir: pathlib.Path):
    source_dir = PROJECT_ROOT / "datasets" / dataset_id / "pipelines" / pipeline_id
    target_dir = env_dir / "datasets" / dataset_id / "pipelines"
    target_dir.mkdir(parents=True, exist_ok=True)
    subprocess.check_call(
        ["cp", "-rf", str(source_dir), str(target_dir)], cwd=PROJECT_ROOT
    )


def build_images(dataset_id: str, env: str, async_builds: bool):
    parent_dir = DATASETS_PATH / dataset_id / "pipelines" / "_images"
    if not parent_dir.exists():
        return

    image_dirs = copy_image_files_to_dot_dir(
        dataset_id, parent_dir, PROJECT_ROOT / f".{env}"
    )
    for image_dir in image_dirs:
        build_and_push_image(dataset_id, image_dir, async_builds)


def copy_image_files_to_dot_dir(
    dataset_id: str, parent_dir: pathlib.Path, env_dir: pathlib.Path
) -> typing.List[pathlib.Path]:
    target_dir = env_dir / "datasets" / dataset_id / "pipelines"
    target_dir.mkdir(parents=True, exist_ok=True)
    subprocess.check_call(
        ["cp", "-rf", str(parent_dir), str(target_dir)], cwd=PROJECT_ROOT
    )
    return list_subdirs(target_dir / "_images")


def build_and_push_image(
    dataset_id: str, image_dir: pathlib.Path, async_builds: bool = False
):
    image_name = f"{dataset_id}__{image_dir.name}"
    command = [
        "gcloud",
        "builds",
        "submit",
        "--async",
        "--tag",
        f"gcr.io/{gcp_project_id()}/{image_name}",
    ]
    if not async_builds:
        command.remove("--async")

    # gcloud builds submit --tag gcr.io/PROJECT_ID/IMAGE_NAME
    subprocess.check_call(command, cwd=image_dir)


def gcp_project_id() -> str:
    _, project_id = google.auth.default()
    return project_id


class CustomYAMLTags(yaml.YAMLObject):
    def __init__(self, dataset):
        self.dataset = dataset
        # Register the !IMAGE tag so pipeline.yaml can reference container images
        # built for this dataset by their short name.
        yaml.add_constructor("!IMAGE", self.image_constructor)

    def image_constructor(self, loader, node):
        value = loader.construct_scalar(node)
        value = f"gcr.io/{{{{ var.value.gcp_project }}}}/{self.dataset}__{value}"
        return value


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate Airflow DAGs for dataset pipelines"
    )
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        type=str,
        dest="dataset",
        help="The directory name of the dataset.",
    )
    parser.add_argument(
        "-p",
        "--pipeline",
        type=str,
        dest="pipeline",
        help="The directory name of the pipeline",
    )
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        default="dev",
        dest="env",
        help="The stage used for the resources: dev|staging|prod",
    )
    parser.add_argument(
        "--all-pipelines", required=False, dest="all_pipelines", action="store_true"
    )
    parser.add_argument(
        "--skip-builds", required=False, dest="skip_builds", action="store_true"
    )
    parser.add_argument(
        "--async-builds", required=False, dest="async_builds", action="store_true"
    )

    args = parser.parse_args()
    main(
        args.dataset,
        args.pipeline,
        args.env,
        args.all_pipelines,
        args.skip_builds,
        args.async_builds,
    )
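
Example invocation, run from the repository root and assuming a pipeline directory datasets/my_dataset/pipelines/my_pipeline/ exists (the dataset and pipeline names here are placeholders):

    python scripts/generate_dag.py --dataset my_dataset --pipeline my_pipeline --env dev --skip-builds

This writes datasets/my_dataset/pipelines/my_pipeline/my_pipeline_dag.py, copies the pipeline folder into .dev/datasets/my_dataset/pipelines/, and prints the Airflow variables the generated DAG expects; --skip-builds bypasses the Cloud Build image step for any _images directory.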