bigquery_etl/query_scheduling/task.py
"""Represents a scheduled Airflow task."""
import copy
import logging
import os
import re
from enum import Enum
from fnmatch import fnmatchcase
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import attr
import cattrs
import click
from bigquery_etl.dependency import extract_table_references_without_views
from bigquery_etl.metadata.parse_metadata import Metadata, PartitionType
from bigquery_etl.query_scheduling.utils import (
is_date_string,
is_email,
is_email_or_github_identity,
is_schedule_interval,
is_valid_dag_name,
schedule_interval_delta,
validate_timedelta_string,
)
AIRFLOW_TASK_TEMPLATE = "airflow_task.j2"
QUERY_FILE_RE = re.compile(
r"^(?:.*/)?([a-zA-Z0-9_-]+)/([a-zA-Z0-9_]+)/"
r"([a-zA-Z0-9_]+)_(v[0-9]+)/(?:query\.sql|part1\.sql|script\.sql|query\.py|checks\.sql|bigconfig\.yml)$"
)
CHECKS_FILE_RE = re.compile(
r"^(?:.*/)?([a-zA-Z0-9_-]+)/([a-zA-Z0-9_]+)/"
r"([a-zA-Z0-9_]+)_(v[0-9]+)/(?:checks\.sql)$"
)
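# Illustrative sketch of what QUERY_FILE_RE captures (the example path below is hypothetical):
#   m = QUERY_FILE_RE.match(
#       "sql/moz-fx-data-shared-prod/telemetry_derived/example_table_v1/query.sql"
#   )
#   m.groups()  # -> ("moz-fx-data-shared-prod", "telemetry_derived", "example_table", "v1")
# CHECKS_FILE_RE captures the same four groups, but only for checks.sql paths.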
DEFAULT_DESTINATION_TABLE_STR = "use-default-destination-table"
MAX_TASK_NAME_LENGTH = 250
class TriggerRule(Enum):
"""Options for task trigger rules."""
ALL_SUCCESS = "all_success"
ALL_FAILED = "all_failed"
ALL_DONE = "all_done"
ONE_FAILED = "one_failed"
ONE_SUCCESS = "one_success"
NONE_FAILED = "none_failed"
NONE_SKIPPED = "none_skipped"
DUMMY = "dummy"
class TaskParseException(Exception):
"""Raised when task scheduling config is invalid."""
def __init__(self, message):
"""Throw TaskParseException."""
message = f"""
{message}
Expected yaml format in metadata.yaml:
scheduling:
dag_name: string [required]
depends_on_past: bool [optional]
... <more config parameters> ...
"""
        super().__init__(message)
class UnscheduledTask(Exception):
"""Raised when a task is not scheduled."""
pass
@attr.s(auto_attribs=True, frozen=True)
class TaskRef:
"""
Representation of a reference to another task.
The task can be defined in bigquery-etl or in telemetry-airflow.
Uses attrs to simplify the class definition and provide validation.
Docs: https://www.attrs.org
"""
dag_name: str = attr.ib()
task_id: str = attr.ib()
execution_delta: Optional[str] = attr.ib(None)
schedule_interval: Optional[str] = attr.ib(None)
date_partition_offset: Optional[int] = attr.ib(None)
task_group: Optional[str] = attr.ib(None)
@property
def task_key(self):
"""Key to uniquely identify the task."""
return (
f"{self.dag_name}.{self.task_group}.{self.task_id}"
if self.task_group
else f"{self.dag_name}.{self.task_id}"
)
@execution_delta.validator
def validate_execution_delta(self, attribute, value):
"""Check that execution_delta is in a valid timedelta format."""
if value is not None:
validate_timedelta_string(value)
@schedule_interval.validator
def validate_schedule_interval(self, attribute, value):
"""Validate the schedule_interval format."""
if value is not None and not is_schedule_interval(value):
raise ValueError(f"Invalid schedule_interval {value}.")
def get_execution_delta(self, schedule_interval):
"""Determine execution_delta, via schedule_interval if necessary."""
if self.execution_delta is not None:
return self.execution_delta
elif self.schedule_interval is not None and schedule_interval is not None:
execution_delta = schedule_interval_delta(
self.schedule_interval, schedule_interval
)
if execution_delta != "0s":
return execution_delta
return None
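# Illustrative sketch of resolving a TaskRef's execution delta (all values hypothetical):
#   upstream = TaskRef(
#       dag_name="bqetl_example",
#       task_id="example_dataset__example_table__v1",
#       schedule_interval="0 2 * * *",
#   )
#   # Without an explicit execution_delta, get_execution_delta() derives the delta from
#   # the two schedule intervals via schedule_interval_delta(); a "0s" delta yields None.
#   upstream.get_execution_delta("0 4 * * *")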
@attr.s(auto_attribs=True, frozen=True)
class TableSensorTask:
"""Representation of a sensor task to wait for a table to exist."""
task_id: str = attr.ib()
table_id: str = attr.ib()
poke_interval: Optional[str] = attr.ib(None, kw_only=True)
timeout: Optional[str] = attr.ib(None, kw_only=True)
retries: Optional[int] = attr.ib(None, kw_only=True)
retry_delay: Optional[str] = attr.ib(None, kw_only=True)
@task_id.validator
def validate_task_id(self, attribute, value):
"""Validate the task ID."""
if len(value) < 1 or len(value) > MAX_TASK_NAME_LENGTH:
raise ValueError(
f"Invalid task ID '{value}'."
f" The task ID has to be 1 to {MAX_TASK_NAME_LENGTH} characters long."
)
if not re.fullmatch(r"\w+", value):
raise ValueError(
f"Invalid task ID '{value}'."
f" The task ID may only contain alphanumerics and underscores."
)
@table_id.validator
def validate_table_id(self, attribute, value):
"""Check that `table_id` is a fully qualified table ID."""
if value.count(".") != 2:
raise ValueError(
f"Invalid table ID '{value}'."
" Table IDs must be fully qualified with the project and dataset."
)
@poke_interval.validator
def validate_poke_interval(self, attribute, value):
"""Check that `poke_interval` is a valid timedelta string."""
if value is not None:
validate_timedelta_string(value)
@timeout.validator
def validate_timeout(self, attribute, value):
"""Check that `timeout` is a valid timedelta string."""
if value is not None:
validate_timedelta_string(value)
@retry_delay.validator
def validate_retry_delay(self, attribute, value):
"""Check that `retry_delay` is a valid timedelta string."""
if value is not None:
validate_timedelta_string(value)
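# Illustrative construction of a TableSensorTask (table name and interval values are
# hypothetical; the interval strings must pass validate_timedelta_string()):
#   TableSensorTask(
#       task_id="wait_for_example_table",
#       table_id="some-project.example_dataset.example_table_v1",
#       poke_interval="30m",
#       timeout="12h",
#   )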
@attr.s(auto_attribs=True, frozen=True)
class TablePartitionSensorTask(TableSensorTask):
"""Representation of a sensor task to wait for a table partition to exist."""
partition_id: str = attr.ib()
@attr.s(auto_attribs=True, frozen=True)
class FivetranTask:
"""Representation of a Fivetran data import task."""
task_id: str = attr.ib()
class SecretDeployType(Enum):
"""Specifies how secret should be exposed in Airflow."""
ENV = "env"
VOLUME = "volume"
@attr.s(auto_attribs=True, frozen=True)
class Secret:
"""Represents the secret configuration used to expose credentials in the task."""
deploy_target: str
key: str
deploy_type: str = attr.ib("env")
secret: str = attr.ib("airflow-gke-secrets")
@deploy_type.validator
def validate_deploy_type(self, attribute, value):
"""Check that deploy_type is a valid option."""
if value is not None and value not in set(
deploy_type.value for deploy_type in SecretDeployType
):
raise ValueError(
f"Invalid deploy_type {value}. Needs to be either 'env' or 'volume'."
)
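# Illustrative construction of a Secret (names are hypothetical):
#   Secret(deploy_target="EXAMPLE_API_KEY", key="example_api_key")
# which keeps the defaults deploy_type="env" and secret="airflow-gke-secrets".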
@attr.s(auto_attribs=True, frozen=True)
class TaskContainerResources:
"""Represents the Kubernetes container resources configuration for a task."""
# example dict: {"memory": "4Gi", "cpu": "500m"}
requests: Optional[Dict[str, str]] = attr.ib(None)
limits: Optional[Dict[str, str]] = attr.ib(None)
# Known tasks in telemetry-airflow, like stable table tasks
# https://github.com/mozilla/telemetry-airflow/blob/main/dags/copy_deduplicate.py
EXTERNAL_TASKS = {
TaskRef(
dag_name="copy_deduplicate",
task_id="copy_deduplicate_main_ping",
schedule_interval="0 1 * * *",
): [
"telemetry_stable.main_v4",
"telemetry_stable.main_v5",
"telemetry_stable.main_use_counter_v4",
],
TaskRef(
dag_name="copy_deduplicate",
task_id="copy_deduplicate_first_shutdown_ping",
schedule_interval="0 1 * * *",
): [
"telemetry_stable.first_shutdown_v4",
"telemetry_stable.first_shutdown_v5",
"telemetry_stable.first_shutdown_use_counter_v4",
],
TaskRef(
dag_name="copy_deduplicate",
task_id="copy_deduplicate_saved_session_ping",
schedule_interval="0 1 * * *",
): [
"telemetry_stable.saved_session_v4",
"telemetry_stable.saved_session_v5",
"telemetry_stable.saved_session_use_counter_v4",
],
TaskRef(
dag_name="copy_deduplicate",
task_id="bq_main_events",
schedule_interval="0 1 * * *",
): ["telemetry_derived.main_events_v1"],
TaskRef(
dag_name="copy_deduplicate",
task_id="event_events",
schedule_interval="0 1 * * *",
): ["telemetry_derived.event_events_v1"],
TaskRef(
dag_name="copy_deduplicate",
task_id="telemetry_derived__core_clients_first_seen__v1",
schedule_interval="0 1 * * *",
): ["*.core_clients_first_seen*"],
TaskRef(
dag_name="copy_deduplicate",
task_id="copy_deduplicate_all",
schedule_interval="0 1 * * *",
): ["*_stable.*"],
}
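# The values above are fnmatch-style patterns matched against "<dataset>.<table>" in
# with_upstream_dependencies(). Illustrative check (the table name is hypothetical):
#   fnmatchcase("search_stable.example_v1", "*_stable.*")  # True -> copy_deduplicate_all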
@attr.s(auto_attribs=True)
class Task:
"""
Representation of a task scheduled in Airflow.
Uses attrs to simplify the class definition and provide validation.
Docs: https://www.attrs.org
"""
dag_name: str = attr.ib()
query_file: str
owner: str = attr.ib()
email: List[str] = attr.ib([])
task_name: Optional[str] = attr.ib(None)
project: str = attr.ib(init=False)
dataset: str = attr.ib(init=False)
table: str = attr.ib(init=False)
version: str = attr.ib(init=False)
depends_on_past: bool = attr.ib(False)
start_date: Optional[str] = attr.ib(None)
date_partition_parameter: Optional[str] = "submission_date"
table_partition_template: Optional[str] = None
# number of days date partition parameter should be offset
date_partition_offset: Optional[int] = None
# indicate whether data should be published as JSON
public_json: bool = attr.ib(False)
# manually specified upstream dependencies
depends_on: List[TaskRef] = attr.ib([])
depends_on_tables_existing: List[TableSensorTask] = attr.ib([])
depends_on_table_partitions_existing: List[TablePartitionSensorTask] = attr.ib([])
depends_on_fivetran: List[FivetranTask] = attr.ib([])
# task trigger rule, used to override default of "all_success"
trigger_rule: Optional[str] = attr.ib(None)
    # manually specified downstream dependencies
external_downstream_tasks: List[TaskRef] = attr.ib([])
# automatically determined upstream and downstream dependencies
upstream_dependencies: List[TaskRef] = attr.ib([])
downstream_dependencies: List[TaskRef] = attr.ib([])
arguments: List[str] = attr.ib([])
parameters: List[str] = attr.ib([])
multipart: bool = attr.ib(False)
query_file_path: Optional[str] = None
priority: Optional[int] = None
referenced_tables: Optional[List[Tuple[str, str, str]]] = attr.ib(None)
destination_table: Optional[str] = attr.ib(default=DEFAULT_DESTINATION_TABLE_STR)
is_python_script: bool = attr.ib(False)
is_dq_check: bool = attr.ib(False)
# Failure of the checks task will stop the dag from executing further
is_dq_check_fail: bool = attr.ib(True)
is_bigeye_check: bool = attr.ib(False)
task_concurrency: Optional[int] = attr.ib(None)
retry_delay: Optional[str] = attr.ib(None)
retries: Optional[int] = attr.ib(None)
email_on_retry: Optional[bool] = attr.ib(None)
gcp_conn_id: Optional[str] = attr.ib(None)
gke_project_id: Optional[str] = attr.ib(None)
gke_location: Optional[str] = attr.ib(None)
gke_cluster_name: Optional[str] = attr.ib(None)
query_project: Optional[str] = attr.ib(None)
task_group: Optional[str] = attr.ib(None)
container_resources: Optional[TaskContainerResources] = attr.ib(None)
node_selector: Optional[Dict[str, str]] = attr.ib(None)
startup_timeout_seconds: Optional[int] = attr.ib(None)
secrets: Optional[List[Secret]] = attr.ib(None)
monitoring_enabled: Optional[bool] = attr.ib(False)
@property
def task_key(self):
"""Key to uniquely identify the task."""
return f"{self.dag_name}.{self.task_name}"
@owner.validator
def validate_owner(self, attribute, value):
"""Check that owner is a valid email address."""
if not is_email_or_github_identity(value):
raise ValueError(
f"Invalid email or github identity for task owner: {value}."
)
@email.validator
def validate_email(self, attribute, value):
"""Check that provided email addresses are valid."""
        if not all(is_email_or_github_identity(e) for e in value):
raise ValueError(f"Invalid email or github identity in DAG email: {value}.")
@start_date.validator
def validate_start_date(self, attribute, value):
"""Check that start_date has YYYY-MM-DD format."""
if value is not None and not is_date_string(value):
raise ValueError(
f"Invalid date definition for {attribute}: {value}."
"Dates should be specified as YYYY-MM-DD."
)
@dag_name.validator
def validate_dag_name(self, attribute, value):
"""Validate the DAG name."""
if not is_valid_dag_name(value):
raise ValueError(
f"Invalid DAG name {value} for task. Name must start with 'bqetl_' "
f"or 'private_bqetl_'."
)
@task_name.validator
def validate_task_name(self, attribute, value):
"""Validate the task name."""
if value is not None:
if len(value) < 1 or len(value) > MAX_TASK_NAME_LENGTH:
raise ValueError(
f"Invalid task name {value}. "
f"The task name has to be 1 to {MAX_TASK_NAME_LENGTH} characters long."
)
@trigger_rule.validator
def validate_trigger_rule(self, attribute, value):
"""Check that trigger_rule is a valid option."""
if value is not None and value not in set(rule.value for rule in TriggerRule):
raise ValueError(
f"Invalid trigger rule {value}. "
"See https://airflow.apache.org/docs/apache-airflow/1.10.3/concepts.html#trigger-rules for list of trigger rules"
)
@retry_delay.validator
def validate_retry_delay(self, attribute, value):
"""Check that retry_delay is in a valid timedelta format."""
if value is not None:
validate_timedelta_string(value)
@task_group.validator
def validate_task_group(self, attribute, value):
"""Check that the task group name is valid."""
        if value is not None and not re.fullmatch(r"[a-zA-Z0-9_]+", value):
raise ValueError(
"Invalid task group identifier. Group name must match pattern [a-zA-Z0-9_]+"
)
def __attrs_post_init__(self):
"""Extract information from the query file name."""
query_file_re = re.search(QUERY_FILE_RE, self.query_file)
if query_file_re:
self.project = query_file_re.group(1)
self.dataset = query_file_re.group(2)
self.table = query_file_re.group(3)
self.version = query_file_re.group(4)
if self.task_name is None:
                # truncate the name from the left so it fits the limit even with long dataset names
self.task_name = f"{self.dataset}__{self.table}__{self.version}"[
-MAX_TASK_NAME_LENGTH:
]
self.validate_task_name(None, self.task_name)
if self.destination_table == DEFAULT_DESTINATION_TABLE_STR:
self.destination_table = f"{self.table}_{self.version}"
if self.destination_table is None and self.query_file_path is None:
raise ValueError(
"One of destination_table or query_file_path must be specified"
)
else:
raise ValueError(
"query_file must be a path with format:"
" sql/<project>/<dataset>/<table>_<version>"
"/(query.sql|part1.sql|script.sql|query.py|checks.sql)"
f" but is {self.query_file}"
)
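    # Illustrative sketch of __attrs_post_init__ (the query path is hypothetical): for
    #   query_file = "sql/moz-fx-data-shared-prod/telemetry_derived/example_table_v1/query.sql"
    # the task gets project="moz-fx-data-shared-prod", dataset="telemetry_derived",
    # table="example_table", version="v1", task_name="telemetry_derived__example_table__v1",
    # and destination_table="example_table_v1", unless overridden in the metadata.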
@classmethod
def of_query(cls, query_file, metadata=None, dag_collection=None):
"""
        Create a task that schedules the corresponding query in Airflow.
        Raises FileNotFoundError if no metadata file exists for the query.
If `metadata` is set, then it is used instead of the metadata.yaml
file that might exist alongside the query file.
"""
converter = cattrs.BaseConverter()
if metadata is None:
metadata = Metadata.of_query_file(query_file)
dag_name = metadata.scheduling.get("dag_name")
if dag_name is None:
raise UnscheduledTask(
f"Metadata for {query_file} does not contain scheduling information."
)
task_config = {"query_file": str(query_file)}
task_config.update(metadata.scheduling)
if len(metadata.owners) <= 0:
raise TaskParseException(
f"No owner specified in metadata for {query_file}."
)
        # Airflow only allows setting a single owner, so we just take the first one
task_config["owner"] = metadata.owners[0]
# Get default email from default_args if available
default_email = []
if dag_collection is not None:
dag = dag_collection.dag_by_name(dag_name)
if dag is not None:
default_email = dag.default_args.email
email = task_config.get("email", default_email)
        # Remove invalid emails (e.g. GitHub identities) from the owners before adding
        # the remaining owners to the Airflow email list.
        # Iterate over a copy because entries are removed during iteration.
        for owner in list(metadata.owners):
            if not is_email(owner):
                metadata.owners.remove(owner)
                click.echo(
                    f"{owner} removed from email list in DAG {metadata.scheduling['dag_name']}"
                )
task_config["email"] = list(set(email + metadata.owners))
# expose secret config
task_config["secrets"] = metadata.scheduling.get("secrets", [])
# to determine if BigEye task should be generated
if metadata.monitoring:
task_config["monitoring_enabled"] = metadata.monitoring.enabled
# data processed in task should be published
if metadata.is_public_json():
task_config["public_json"] = True
        # Override the default `table_partition_template` when no `destination_table` is
        # set in the scheduling section of the metadata: pass a Jinja template that
        # reformats the date string used for the table partition decorator.
        # See the docs for the formatting conventions:
        # https://cloud.google.com/bigquery/docs/managing-partitioned-table-data#partition_decorators
if (
metadata.bigquery
and metadata.bigquery.time_partitioning
and metadata.scheduling.get("destination_table") is None
):
match metadata.bigquery.time_partitioning.type:
case PartitionType.YEAR:
partition_template = '${{ dag_run.logical_date.strftime("%Y") }}'
case PartitionType.MONTH:
partition_template = '${{ dag_run.logical_date.strftime("%Y%m") }}'
case PartitionType.DAY:
# skip for the default case of daily partitioning
partition_template = None
case PartitionType.HOUR:
partition_template = (
'${{ dag_run.logical_date.strftime("%Y%m%d%H") }}'
)
case _:
raise TaskParseException(
f"Invalid partition type: {metadata.bigquery.time_partitioning.type}"
)
if partition_template:
task_config["table_partition_template"] = partition_template
try:
return copy.deepcopy(converter.structure(task_config, cls))
except TypeError as e:
raise TaskParseException(
f"Invalid scheduling information format for {query_file}: {e}"
)
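    # Illustrative sketch of of_query() (the path and metadata values are hypothetical):
    #   task = Task.of_query("sql/moz-fx-data-shared-prod/telemetry_derived/example_table_v1/query.sql")
    # For a monthly time-partitioned table with no destination_table override, the resulting
    # task carries table_partition_template='${{ dag_run.logical_date.strftime("%Y%m") }}'.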
@classmethod
def of_multipart_query(cls, query_file, metadata=None, dag_collection=None):
"""
        Create a task that schedules the corresponding multipart query in Airflow.
        Raises FileNotFoundError if no metadata file exists for the query.
If `metadata` is set, then it is used instead of the metadata.yaml
file that might exist alongside the query file.
"""
task = cls.of_query(query_file, metadata, dag_collection)
task.multipart = True
task.query_file_path = os.path.dirname(query_file)
return task
@classmethod
def of_script(cls, query_file, metadata=None, dag_collection=None):
"""
        Create a task that schedules the corresponding script in Airflow.
        Raises FileNotFoundError if no metadata file exists for the query.
If `metadata` is set, then it is used instead of the metadata.yaml
file that might exist alongside the query file.
"""
task = cls.of_query(query_file, metadata, dag_collection)
task.query_file_path = query_file
task.destination_table = None
return task
@classmethod
def of_python_script(cls, query_file, metadata=None, dag_collection=None):
"""
Create a task that schedules the Python script file in Airflow.
        Raises FileNotFoundError if no metadata file exists for the query.
If `metadata` is set, then it is used instead of the metadata.yaml
file that might exist alongside the query file.
"""
task = cls.of_query(query_file, metadata, dag_collection)
task.query_file_path = query_file
task.is_python_script = True
return task
@classmethod
def of_dq_check(cls, query_file, is_check_fail, metadata=None, dag_collection=None):
"""Create a task that schedules DQ check file in Airflow."""
task = cls.of_query(query_file, metadata, dag_collection)
task.query_file_path = query_file
task.is_dq_check = True
task.is_dq_check_fail = is_check_fail
task.depends_on_past = False
task.retries = 0
task.depends_on_fivetran = []
task.referenced_tables = None
task.depends_on = []
if task.is_dq_check_fail:
task.task_name = (
f"checks__fail_{task.dataset}__{task.table}__{task.version}"[
-MAX_TASK_NAME_LENGTH:
]
)
task.validate_task_name(None, task.task_name)
else:
task.task_name = (
f"checks__warn_{task.dataset}__{task.table}__{task.version}"[
-MAX_TASK_NAME_LENGTH:
]
)
task.validate_task_name(None, task.task_name)
return task
@classmethod
def of_bigeye_check(cls, query_file, metadata=None, dag_collection=None):
"""Create a task to trigger BigEye metric run via Airflow."""
task = cls.of_query(query_file, metadata, dag_collection)
task.query_file_path = None
task.is_bigeye_check = True
task.depends_on_past = False
task.destination_table = None
task.retries = 1
task.depends_on_fivetran = []
task.referenced_tables = None
task.depends_on = []
if task.is_bigeye_check:
task.task_name = f"bigeye__{task.dataset}__{task.table}__{task.version}"[
-MAX_TASK_NAME_LENGTH:
]
task.validate_task_name(None, task.task_name)
return task
def to_ref(self, dag_collection):
"""Return the task as `TaskRef`."""
return TaskRef(
dag_name=self.dag_name,
task_id=self.task_name,
date_partition_offset=self.date_partition_offset,
schedule_interval=dag_collection.dag_by_name(
self.dag_name
).schedule_interval,
task_group=self.task_group,
)
def _get_referenced_tables(self):
"""Use sqlglot to get tables the query depends on."""
logging.info(f"Get dependencies for {self.task_key}")
if self.is_python_script or self.is_bigeye_check:
# cannot do dry runs for python scripts or BigEye config
return self.referenced_tables or []
if self.referenced_tables is None:
query_files = [Path(self.query_file)]
if self.multipart:
# dry_run all files if query is split into multiple parts
query_files = Path(self.query_file_path).glob("*.sql")
table_names = {
tuple(table.split("."))
for query_file in query_files
for table in extract_table_references_without_views(query_file)
}
# the order of table dependencies changes between requests
# sort to maintain same order between DAG generation runs
self.referenced_tables = sorted(table_names)
return self.referenced_tables
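    # Illustrative result shape of _get_referenced_tables() (values are hypothetical):
    # referenced_tables ends up as a sorted list of (project, dataset, table) tuples, e.g.
    #   [("moz-fx-data-shared-prod", "telemetry_derived", "example_table_v1"),
    #    ("moz-fx-data-shared-prod", "telemetry_stable", "main_v5")]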
def with_upstream_dependencies(self, dag_collection):
"""Perform a dry_run to get upstream dependencies."""
if self.upstream_dependencies:
return
dependencies = []
def _duplicate_dependency(task_ref):
return any(
d.task_key == task_ref.task_key for d in self.depends_on + dependencies
)
parent_task = None
if self.is_dq_check or self.is_bigeye_check:
parent_task = dag_collection.task_for_table(
self.project, self.dataset, f"{self.table}_{self.version}"
)
parent_task_ref = parent_task.to_ref(dag_collection)
if not _duplicate_dependency(parent_task_ref):
dependencies.append(parent_task_ref)
for table in self._get_referenced_tables():
# check if upstream task is accompanied by a check
# the task running the check will be set as the upstream task instead
checks_upstream_task = dag_collection.fail_checks_task_for_table(
table[0], table[1], table[2]
)
bigeye_checks_upstream_task = (
dag_collection.fail_bigeye_checks_task_for_table(
table[0], table[1], table[2]
)
)
upstream_task = dag_collection.task_for_table(table[0], table[1], table[2])
if upstream_task is not None:
if upstream_task != self and upstream_task != parent_task:
if checks_upstream_task is not None:
upstream_task = checks_upstream_task
if bigeye_checks_upstream_task is not None:
upstream_task = bigeye_checks_upstream_task
task_ref = upstream_task.to_ref(dag_collection)
if not _duplicate_dependency(task_ref):
# Get its upstream dependencies so its date_partition_offset gets set.
upstream_task.with_upstream_dependencies(dag_collection)
task_ref = upstream_task.to_ref(dag_collection)
dependencies.append(task_ref)
else:
# see if there are some static dependencies
for task_ref, patterns in EXTERNAL_TASKS.items():
if any(fnmatchcase(f"{table[1]}.{table[2]}", p) for p in patterns):
if not _duplicate_dependency(task_ref):
dependencies.append(task_ref)
break # stop after the first match
if (
self.date_partition_parameter is not None
and self.date_partition_offset is None
):
# adjust submission_date parameter based on whether upstream tasks have
# date partition offsets
date_partition_offsets = [
dependency.date_partition_offset
for dependency in dependencies
if dependency.date_partition_offset
]
if len(date_partition_offsets) > 0:
self.date_partition_offset = min(date_partition_offsets)
# unset the table_partition_template property if we have an offset
# as that will be overridden in the template via `destination_table`
self.table_partition_template = None
date_partition_offset_task_keys = [
dependency.task_key
for dependency in dependencies
if dependency.date_partition_offset == self.date_partition_offset
]
logging.info(
f"Set {self.task_key} date partition offset"
f" to {self.date_partition_offset}"
f" based on {', '.join(date_partition_offset_task_keys)}."
)
self.upstream_dependencies = dependencies
def with_downstream_dependencies(self, dag_collection):
"""Get downstream tasks by looking up upstream dependencies in DAG collection."""
self.downstream_dependencies = [
task_ref
for task_ref in dag_collection.get_task_downstream_dependencies(self)
if task_ref.dag_name != self.dag_name
]