bigquery_etl/query_scheduling/dag.py
"""Represents an Airflow DAG."""
from typing import List, Optional, Set
import attr
import cattrs
from jinja2 import Environment, PackageLoader
from bigquery_etl.config import ConfigLoader
from bigquery_etl.query_scheduling import formatters
from bigquery_etl.query_scheduling.task import Task, TaskRef
from bigquery_etl.query_scheduling.utils import (
is_date_string,
is_email_or_github_identity,
is_schedule_interval,
is_valid_dag_name,
schedule_interval_delta,
validate_timedelta_string,
)

AIRFLOW_DAG_TEMPLATE = "airflow_dag.j2"
PUBLIC_DATA_JSON_DAG_TEMPLATE = "public_data_json_airflow_dag.j2"
PUBLIC_DATA_JSON_DAG = "bqetl_public_data_json"
CONFIDENTIAL_TAG = "triage/confidential"


class DagParseException(Exception):
    """Raised when DAG config is invalid."""

    def __init__(self, message):
        """Throw DagParseException."""
        message = f"""
{message}
Expected yaml format:
name:
schedule_interval: string,
default_args:
owner: string
start_date: 'YYYY-MM-DD'
...
"""
        super().__init__(message)


class InvalidDag(Exception):
    """Raised when the resulting DAG is invalid."""

    pass


@attr.s(auto_attribs=True)
class DagDefaultArgs:
"""
Representation of Airflow DAG default_args.
Uses attrs to simplify the class definition and provide validation.
Docs: https://www.attrs.org
"""
owner: str = attr.ib()
start_date: str = attr.ib()
end_date: Optional[str] = attr.ib(None)
email: List[str] = attr.ib([])
depends_on_past: bool = attr.ib(False)
retry_delay: str = attr.ib("30m")
email_on_failure: bool = attr.ib(True)
email_on_retry: bool = attr.ib(True)
retries: int = attr.ib(2)
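    # Caps concurrent instances of the same task across DAG runs (passed through
    # to Airflow's max_active_tis_per_dag); None keeps the Airflow default.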
max_active_tis_per_dag: Optional[int] = attr.ib(None)

    @owner.validator
def validate_owner(self, attribute, value):
"""Check that owner is a valid email address."""
if not is_email_or_github_identity(value):
raise ValueError(
f"Invalid email or github identity for DAG owner: {value}."
)

    @email.validator
    def validate_email(self, attribute, value):
        """Check that provided email addresses are valid."""
        if not all(is_email_or_github_identity(e) for e in value):
            raise ValueError(f"Invalid email or github identity in DAG email: {value}.")

    @retry_delay.validator
def validate_retry_delay(self, attribute, value):
"""Check that retry_delay is in a valid timedelta format."""
validate_timedelta_string(value)

    @end_date.validator
    @start_date.validator
    def validate_date(self, attribute, value):
        """Check that start_date and end_date have YYYY-MM-DD format."""
        if value is not None and not is_date_string(value):
            raise ValueError(
                f"Invalid date definition for {attribute}: {value}. "
                "Dates should be specified as YYYY-MM-DD."
            )

    def to_dict(self):
"""Return class as a dict."""
return self.__dict__
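
# Example (sketch): attrs runs the validators at construction time, so invalid
# values raise immediately. The values shown here are hypothetical.
#
#     args = DagDefaultArgs(owner="example@mozilla.com", start_date="2023-01-01")
#     args.to_dict()["owner"]            # "example@mozilla.com"
#     DagDefaultArgs(owner="???", start_date="2023-01-01")  # raises ValueError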


@attr.s(auto_attribs=True)
class Dag:
"""
Representation of a DAG configuration.
Uses attrs to simplify the class definition and provide validation.
Docs: https://www.attrs.org
"""
name: str = attr.ib()
schedule_interval: str = attr.ib()
default_args: DagDefaultArgs
tasks: List[Task] = attr.ib([])
description: str = attr.ib("")
repo: str = attr.ib("bigquery-etl")
tags: List[str] = attr.ib([])
catchup: bool = attr.ib(False)

    @name.validator
def validate_dag_name(self, attribute, value):
"""Validate the DAG name."""
if not is_valid_dag_name(value):
raise ValueError(
f"Invalid DAG name {value}. Name must start with 'bqetl_' "
f"or 'private_bqetl_'."
)

    @tasks.validator
    def validate_tasks(self, attribute, value):
        """Validate tasks."""
        task_names = [task.task_name for task in value]
        duplicate_task_names = {
            task_name for task_name in task_names if task_names.count(task_name) > 1
        }
if len(duplicate_task_names) > 0:
raise ValueError(
f"Duplicate task names encountered: {duplicate_task_names}."
)

    @schedule_interval.validator
def validate_schedule_interval(self, attribute, value):
"""Validate the schedule_interval format."""
if not is_schedule_interval(value):
raise ValueError(f"Invalid schedule_interval {value}.")

    def add_tasks(self, tasks):
"""Add tasks to be scheduled as part of the DAG."""
self.tasks = self.tasks.copy() + tasks
self.validate_tasks(None, self.tasks)

    def with_upstream_dependencies(self, dag_collection):
        """Perform a dry run to determine upstream dependencies."""
for task in self.tasks:
task.with_upstream_dependencies(dag_collection)

    def with_downstream_dependencies(self, dag_collection):
"""Get downstream tasks by looking up upstream dependencies in DAG collection."""
for task in self.tasks:
task.with_downstream_dependencies(dag_collection)

    def to_dict(self):
        """Return class as a dict."""
        # Work on a copy so serializing does not strip attributes off the instance.
        d = self.__dict__.copy()
name = d["name"]
del d["name"]
del d["tasks"]
d["default_args"] = self.default_args.to_dict()
return {name: d}

    @property
    def task_groups(self) -> Set[str]:
        """
        Return the set of task groups in this DAG.

        Task groups are specified as part of the task configurations.
        """
return {task.task_group for task in self.tasks if task.task_group is not None}

    @classmethod
    def from_dict(cls: type, d: dict):
        """
        Parse the DAG configuration from a dict and create a new Dag instance.

        Expected dict format:
        {
            "name": {
                "schedule_interval": string,
                "default_args": dict
            }
        }
        """
if len(d.keys()) != 1:
raise DagParseException(f"Invalid DAG configuration format in {d}")
converter = cattrs.BaseConverter()
try:
name = list(d.keys())[0]
tags: set[str] = set(d[name].get("tags", []))
if not any(tag.startswith("repo/") for tag in tags):
tags.add("repo/" + d[name].get("repo", "bigquery-etl"))
if name.startswith("private_") and CONFIDENTIAL_TAG not in tags:
tags.add(
CONFIDENTIAL_TAG,
)
d[name]["tags"] = sorted(tags)
if name == PUBLIC_DATA_JSON_DAG:
return converter.structure({"name": name, **d[name]}, PublicDataJsonDag)
else:
return converter.structure({"name": name, **d[name]}, cls)
except (TypeError, AttributeError) as e:
raise DagParseException(f"Invalid DAG configuration format in {d}: {e}")

    def _jinja_env(self):
"""Prepare and load custom formatters into the jinja environment."""
env = Environment(
loader=PackageLoader("bigquery_etl", "query_scheduling/templates"),
extensions=["jinja2.ext.do"],
)
# load custom formatters into Jinja env
for name in dir(formatters):
func = getattr(formatters, name)
if not callable(func):
continue
env.filters[name] = func
return env

    def to_airflow_dag(self):
"""Convert the DAG to its Airflow representation and return the python code."""
if len(self.tasks) == 0:
raise InvalidDag(
f"DAG {self.name} has no tasks - cannot convert it to a valid .py DAG "
f"file. Does it appear under `scheduling` in any metadata.yaml files?"
)
env = self._jinja_env()
dag_template = env.get_template(AIRFLOW_DAG_TEMPLATE)
        # Copy so template-only keys are not attached to the instance itself.
        args = self.__dict__.copy()
args["task_groups"] = self.task_groups
args["bigeye_warehouse_id"] = ConfigLoader.get(
"monitoring", "bigeye_warehouse_id", fallback=1939
)
args["bigeye_conn_id"] = ConfigLoader.get(
"monitoring", "bigeye_conn_id", fallback="bigeye_connection"
)
return dag_template.render(args)
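
    # Example (sketch): writing the rendered DAG out as Python source; the
    # destination path is hypothetical.
    #
    #     with open(f"dags/{dag.name}.py", "w") as f:
    #         f.write(dag.to_airflow_dag())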


class PublicDataJsonDag(Dag):
    """Special DAG with tasks exporting public json data to GCS."""

    def to_airflow_dag(self):
"""Convert the DAG to its Airflow representation and return the python code."""
env = self._jinja_env()
dag_template = env.get_template(PUBLIC_DATA_JSON_DAG_TEMPLATE)
args = self.__dict__
return dag_template.render(args)

    def _create_export_task(self, task, dag_collection):
        """Create a task that exports the public JSON data of `task` to GCS."""
if not task.public_json:
raise ValueError(f"Task {task.task_name} not marked as public JSON.")
converter = cattrs.BaseConverter()
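        # Round-trip through cattrs to clone the task, dropping the fields that
        # identify the source table so the clone becomes a generic export task.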
task_dict = converter.unstructure(task)
del task_dict["dataset"]
del task_dict["table"]
del task_dict["version"]
del task_dict["project"]
export_task = converter.structure(task_dict, Task)
export_task.dag_name = self.name
export_task.task_name = f"export_public_data_json_{export_task.task_name}"
task_schedule_interval = dag_collection.dag_by_name(
task.dag_name
).schedule_interval
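        # Offset the upstream dependency by the difference between the two DAGs'
        # schedules; "0s" means they align, so no execution_delta is needed.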
execution_delta = schedule_interval_delta(
task_schedule_interval, self.schedule_interval
)
if execution_delta == "0s":
execution_delta = None
export_task.dependencies = [
TaskRef(
dag_name=task.dag_name,
task_id=task.task_name,
execution_delta=execution_delta,
)
]
return export_task

    def add_export_tasks(self, tasks, dag_collection):
"""Add new tasks for exporting data of the original queries to GCS."""
self.add_tasks(
[self._create_export_task(task, dag_collection) for task in tasks]
)