bigquery_etl/query_scheduling/dag_collection.py (123 lines of code) (raw):

"""Represents a collection of configured Airflow DAGs."""

from collections import defaultdict
from functools import partial
from itertools import groupby
from multiprocessing import get_context, set_start_method
from operator import attrgetter
from pathlib import Path

import yaml
from black import FileMode, format_file_contents

from bigquery_etl.query_scheduling.dag import Dag, InvalidDag, PublicDataJsonDag


class DagCollection:
    """Representation of all configured DAGs."""

    def __init__(self, dags):
        """Instantiate DAGs."""
        self.dags = dags
        self.dags_by_name = {dag.name: dag for dag in dags}

    @classmethod
    def from_dict(cls, d):
        """
        Parse DAG configurations from a dict and create new instances.

        Expected dict format:
        {
            "bqetl_dag_name1": {
                "schedule_interval": string,
                "default_args": {
                    "owner": string,
                    "start_date": "YYYY-MM-DD",
                    ...
                }
            },
            "bqetl_dag_name2": {
                "schedule_interval": string,
                "default_args": {
                    "owner": string,
                    "start_date": "YYYY-MM-DD",
                    ...
                }
            },
            ...
        }

        A ``None`` config (e.g. an empty YAML file) yields an empty collection.
        """
        if d is None:
            return cls([])

        dags = [Dag.from_dict({k: v}) for k, v in d.items()]
        return cls(dags)

    @classmethod
    def from_file(cls, config_file):
        """Instantiate DAGs based on the provided YAML configuration file."""
        with open(config_file, "r", encoding="utf-8") as yaml_stream:
            dags_config = yaml.safe_load(yaml_stream)
        # Use ``cls`` (not a hard-coded class name) so subclasses are preserved.
        return cls.from_dict(dags_config)

    def dag_by_name(self, name):
        """Return the DAG with the provided name, or None if it does not exist."""
        return self.dags_by_name.get(name)

    def _iter_tasks(self):
        """Yield every task of every DAG, in DAG order and then task order."""
        for dag in self.dags:
            yield from dag.tasks

    @staticmethod
    def _matches_table(task, project, dataset, table):
        """Return True if ``task`` targets ``project.dataset.table``.

        ``table`` is compared against ``f"{task.table}_{task.version}"``, so it
        must include the version suffix.
        """
        return (
            project == task.project
            and dataset == task.dataset
            and table == f"{task.table}_{task.version}"
        )

    def task_for_table(self, project, dataset, table):
        """Return the task that schedules the query for the provided table.

        DQ-check and BigEye-check tasks for the same table are skipped.
        Returns None if no matching task exists.
        """
        return next(
            (
                task
                for task in self._iter_tasks()
                if self._matches_table(task, project, dataset, table)
                and not task.is_dq_check
                and not task.is_bigeye_check
            ),
            None,
        )

    def fail_checks_task_for_table(self, project, dataset, table):
        """Return the task that schedules the checks for the provided table.

        Only DQ-check tasks with ``is_dq_check_fail`` set are considered.
        Returns None if no matching task exists.
        """
        return next(
            (
                task
                for task in self._iter_tasks()
                if self._matches_table(task, project, dataset, table)
                and task.is_dq_check
                and task.is_dq_check_fail
            ),
            None,
        )

    def fail_bigeye_checks_task_for_table(self, project, dataset, table):
        """Return the task that schedules the BigEye checks for the provided table.

        Returns None if no matching task exists.
        """
        return next(
            (
                task
                for task in self._iter_tasks()
                if self._matches_table(task, project, dataset, table)
                and task.is_bigeye_check
            ),
            None,
        )

    def with_tasks(self, tasks):
        """Assign tasks to their corresponding DAGs.

        Tasks are grouped by ``dag_name`` and added to the DAG of that name.
        Tasks with ``public_json`` set (excluding DQ and BigEye checks) are also
        registered as export tasks on this collection's ``PublicDataJsonDag``,
        if one is configured.

        Raises InvalidDag if a task references a DAG that is not configured.
        Returns self to allow chaining.
        """
        # ``tasks`` is traversed twice below; materialize it so a one-shot
        # iterator (e.g. a generator) is not exhausted by the first pass.
        tasks = list(tasks)

        get_dag_name = attrgetter("dag_name")
        # groupby only merges adjacent items, hence the sort by the same key.
        for dag_name, group in groupby(sorted(tasks, key=get_dag_name), get_dag_name):
            dag = self.dag_by_name(dag_name)
            if dag is None:
                raise InvalidDag(
                    f"DAG {dag_name} does not exist in dags.yaml "
                    f"but used in task definition {next(group).task_name}."
                )
            dag.add_tasks(list(group))

        public_json_tasks = [
            task
            for task in tasks
            if task.public_json and not (task.is_dq_check or task.is_bigeye_check)
        ]
        if public_json_tasks:
            public_data_json_dag = None
            # Exact type match; if several such DAGs exist, the last one wins.
            for dag in self.dags:
                if type(dag) is PublicDataJsonDag:
                    public_data_json_dag = dag
            if public_data_json_dag:
                public_data_json_dag.add_export_tasks(public_json_tasks, self)

        return self

    def get_task_downstream_dependencies(self, task):
        """Return all direct downstream dependencies of the task.

        On first call this builds a mapping from each upstream ``task_key`` to
        refs (``to_ref``) of the tasks that list it in ``depends_on`` or
        ``upstream_dependencies``, and caches it on the instance. Building the
        map calls ``with_upstream_dependencies(self)`` on every task.

        NOTE: the cache is never invalidated, so tasks added after the first
        call (e.g. via ``with_tasks``) are not reflected.
        """
        # Cache the downstream dependencies for faster lookups.
        if not hasattr(self, "_downstream_dependencies"):
            downstream_dependencies = defaultdict(list)
            for dag in self.dags:
                for _task in dag.tasks:
                    _task.with_upstream_dependencies(self)
                    for upstream_dependency in (
                        _task.depends_on + _task.upstream_dependencies
                    ):
                        downstream_dependencies[upstream_dependency.task_key].append(
                            _task.to_ref(self)
                        )
            self._downstream_dependencies = downstream_dependencies
        return self._downstream_dependencies[task.task_key]

    def dag_to_airflow(self, output_dir, dag):
        """Generate the Airflow DAG representation for the provided DAG.

        The generated source is formatted with black and written to
        ``<output_dir>/<dag.name>.py``. If rendering raises InvalidDag, the
        error is printed and no file is written for that DAG.
        """
        output_file = Path(output_dir) / (dag.name + ".py")
        try:
            formatted_dag = format_file_contents(
                dag.to_airflow_dag(), fast=False, mode=FileMode()
            )
            output_file.write_text(formatted_dag)
        except InvalidDag as e:
            print(e)

    def to_airflow_dags(self, output_dir, dag_to_generate=None):
        """Write DAG representation as Airflow dags to file.

        If ``dag_to_generate`` is given, only that DAG's dependencies are
        resolved and only that DAG is written, in the current process.
        Otherwise every DAG's dependencies are resolved first, and the DAGs are
        then rendered in parallel by a pool of 8 "spawn" worker processes.
        """
        # Generate a single DAG:
        if dag_to_generate is not None:
            dag_to_generate.with_upstream_dependencies(self)
            dag_to_generate.with_downstream_dependencies(self)
            self.dag_to_airflow(output_dir, dag_to_generate)
            return

        # Generate all DAGs:
        # https://pythonspeed.com/articles/python-multiprocessing/
        # when running tests on CI that call this function, we need
        # to create a custom pool to prevent processes from getting stuck
        try:
            set_start_method("spawn")
        except RuntimeError:
            # The global start method can only be set once per process; a
            # repeat call raises RuntimeError. The pool below uses an explicit
            # "spawn" context regardless, so this is safe to ignore.
            pass

        for dag in self.dags:
            dag.with_upstream_dependencies(self)
            dag.with_downstream_dependencies(self)

        to_airflow_dag = partial(self.dag_to_airflow, output_dir)
        with get_context("spawn").Pool(8) as p:
            p.map(to_airflow_dag, self.dags)