bigquery_etl/query_scheduling/dag_collection.py
"""Represents a collection of configured Airflow DAGs."""
from collections import defaultdict
from functools import partial
from itertools import groupby
from multiprocessing import get_context, set_start_method
from operator import attrgetter
from pathlib import Path
import yaml
from black import FileMode, format_file_contents
from bigquery_etl.query_scheduling.dag import Dag, InvalidDag, PublicDataJsonDag
class DagCollection:
"""Representation of all configured DAGs."""
def __init__(self, dags):
"""Instantiate DAGs."""
self.dags = dags
self.dags_by_name = {dag.name: dag for dag in dags}
@classmethod
def from_dict(cls, d):
"""
Parse DAG configurations from a dict and create new instances.
Expected dict format:
{
"bqetl_dag_name1": {
"schedule_interval": string,
"default_args": {
"owner": string,
"start_date": "YYYY-MM-DD",
...
}
},
"bqetl_dag_name2": {
"schedule_interval": string,
"default_args": {
"owner": string,
"start_date": "YYYY-MM-DD",
...
}
},
...
}
"""
if d is None:
return cls([])
dags = [Dag.from_dict({k: v}) for k, v in d.items()]
return cls(dags)
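    # A minimal construction sketch mirroring the docstring format above; the DAG
    # name and default_args values are illustrative, and Dag.from_dict may require
    # additional default_args keys beyond the ones shown:
    #
    #     dags = DagCollection.from_dict(
    #         {
    #             "bqetl_example": {
    #                 "schedule_interval": "0 2 * * *",
    #                 "default_args": {
    #                     "owner": "example@mozilla.com",
    #                     "start_date": "2023-01-01",
    #                 },
    #             }
    #         }
    #     )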
@classmethod
def from_file(cls, config_file):
"""Instantiate DAGs based on the provided configuration file."""
with open(config_file, "r") as yaml_stream:
dags_config = yaml.safe_load(yaml_stream)
return DagCollection.from_dict(dags_config)
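    # Usage sketch; dags.yaml is the DAG configuration file referenced elsewhere
    # in this module:
    #
    #     dags = DagCollection.from_file("dags.yaml")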
def dag_by_name(self, name):
"""Return the DAG with the provided name."""
return self.dags_by_name.get(name)
def task_for_table(self, project, dataset, table):
"""Return the task that schedules the query for the provided table."""
for dag in self.dags:
for task in dag.tasks:
if (
project == task.project
and dataset == task.dataset
and table == f"{task.table}_{task.version}"
and not task.is_dq_check
and not task.is_bigeye_check
):
return task
return None
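    # Lookup sketch; the project, dataset, and table names are illustrative. The
    # table argument is the versioned table name, i.e. "<task.table>_<task.version>":
    #
    #     task = dags.task_for_table(
    #         "moz-fx-data-shared-prod", "telemetry_derived", "clients_daily_v6"
    #     )
    #     # task is None if no non-check task schedules that table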
def fail_checks_task_for_table(self, project, dataset, table):
"""Return the task that schedules the checks for the provided table."""
for dag in self.dags:
for task in dag.tasks:
if (
project == task.project
and dataset == task.dataset
and table == f"{task.table}_{task.version}"
and task.is_dq_check
and task.is_dq_check_fail
):
return task
return None
def fail_bigeye_checks_task_for_table(self, project, dataset, table):
"""Return the task that schedules the BigEye checks for the provided table."""
for dag in self.dags:
for task in dag.tasks:
if (
project == task.project
and dataset == task.dataset
and table == f"{task.table}_{task.version}"
and task.is_bigeye_check
):
return task
return None
def with_tasks(self, tasks):
"""Assign tasks to their corresponding DAGs."""
public_data_json_dag = None
get_dag_name = attrgetter("dag_name")
for dag_name, group in groupby(sorted(tasks, key=get_dag_name), get_dag_name):
dag = self.dag_by_name(dag_name)
if dag is None:
raise InvalidDag(
f"DAG {dag_name} does not exist in dags.yaml "
f"but used in task definition {next(group).task_name}."
)
dag.add_tasks(list(group))
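        # Tasks that publish public JSON additionally get export tasks attached
        # to the public data JSON DAG; DQ and BigEye check tasks are excluded.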
public_json_tasks = [
task
for task in tasks
if task.public_json and not (task.is_dq_check or task.is_bigeye_check)
]
if public_json_tasks:
for dag in self.dags:
if dag.__class__ == PublicDataJsonDag:
public_data_json_dag = dag
if public_data_json_dag:
public_data_json_dag.add_export_tasks(public_json_tasks, self)
return self
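    # Typical flow sketch; `tasks` stands for Task objects parsed from query
    # scheduling metadata elsewhere in bigquery_etl, each carrying a dag_name:
    #
    #     dags = DagCollection.from_file("dags.yaml").with_tasks(tasks)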
def get_task_downstream_dependencies(self, task):
"""Return all direct downstream dependencies of the task."""
# Cache the downstream dependencies for faster lookups.
if not hasattr(self, "_downstream_dependencies"):
downstream_dependencies = defaultdict(list)
for dag in self.dags:
for _task in dag.tasks:
_task.with_upstream_dependencies(self)
for upstream_dependency in (
_task.depends_on + _task.upstream_dependencies
):
downstream_dependencies[upstream_dependency.task_key].append(
_task.to_ref(self)
)
self._downstream_dependencies = downstream_dependencies
return self._downstream_dependencies[task.task_key]
def dag_to_airflow(self, output_dir, dag):
"""Generate the Airflow DAG representation for the provided DAG."""
output_file = Path(output_dir) / (dag.name + ".py")
try:
formatted_dag = format_file_contents(
dag.to_airflow_dag(), fast=False, mode=FileMode()
)
output_file.write_text(formatted_dag)
except InvalidDag as e:
print(e)
def to_airflow_dags(self, output_dir, dag_to_generate=None):
"""Write DAG representation as Airflow dags to file."""
        # https://pythonspeed.com/articles/python-multiprocessing/
        # When tests on CI call this function, we need to create a pool with a
        # "spawn" context to prevent worker processes from getting stuck.
        # Generate a single DAG:
if dag_to_generate is not None:
dag_to_generate.with_upstream_dependencies(self)
dag_to_generate.with_downstream_dependencies(self)
self.dag_to_airflow(output_dir, dag_to_generate)
return
# Generate all DAGs:
try:
set_start_method("spawn")
except Exception:
pass
for dag in self.dags:
dag.with_upstream_dependencies(self)
dag.with_downstream_dependencies(self)
to_airflow_dag = partial(self.dag_to_airflow, output_dir)
with get_context("spawn").Pool(8) as p:
p.map(to_airflow_dag, self.dags)
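    # End-to-end sketch; the output directory is illustrative. Each configured
    # DAG is written to <output_dir>/<dag_name>.py:
    #
    #     dags = DagCollection.from_file("dags.yaml").with_tasks(tasks)
    #     dags.to_airflow_dags("generated-dags/")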