src/common/materializer/dependent_dags.py

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines utility functions to help construct task dependent DAGs."""

from collections import defaultdict
import dataclasses
import datetime
import graphlib
import logging
from pathlib import Path
import textwrap
from typing import Optional

from common.materializer import dag_types
from common.materializer import generate_assets
from common.py_libs import constants
from common.py_libs import cortex_bq_client
from common.py_libs import dag_generator
from common.py_libs import jinja


def _get_parent_task_ids(parent_files: list[str]) -> list[str]:
    task_ids = []
    for p in parent_files:
        if p == "start_task":
            task_ids.append(p)
        else:
            task_ids.append(f"refresh_{Path(p).stem}")
    return task_ids


def _get_dag_setting(obj: dag_types.BqObject) -> Optional[dag_types.Dag]:
    if not (obj.table_setting and obj.table_setting.dag_setting):
        return None
    return obj.table_setting.dag_setting


def get_validated_load_freq(
        top_level_objs: dict[str, dag_types.BqObject]) -> str:
    """Returns the validated load_frequency defined by the top level nodes.

    Top level nodes must define a consistent schedule. Multiple nodes may
    define the same load_frequency, or leave it unset, as long as at least
    one of the top level nodes has it set.

    Args:
        top_level_objs: Map of top level BQ objects that are direct children
            of the DAG root, keyed by SQL file name.

    Returns:
        Validated load_frequency string defining the DAG schedule.

    Raises:
        ValueError: If none of the top level nodes defines a load_frequency,
            or if their values conflict.
    """
    schedule = None
    for sql_file, obj in top_level_objs.items():
        assert obj.table_setting is not None
        if obj.table_setting.load_frequency is None:
            continue
        load_frequency = obj.table_setting.load_frequency
        if schedule is None:
            schedule = load_frequency
            continue
        if load_frequency != schedule:
            raise ValueError(
                "Top level node detected with conflicting load_frequency. "
                f"{sql_file} set to '{load_frequency}' but expected "
                f"'{schedule}'. Please resolve any load_frequency conflicts "
                f"among top level nodes: {list(top_level_objs.keys())}")
    if schedule is None:
        raise ValueError(
            "At least one top level node must define load_frequency. Top "
            f"level nodes include: {list(top_level_objs.keys())}")
    return schedule


def get_task_deps(bq_object_settings: dict) -> dict[str, dag_types.BqObject]:
    """Gets a dictionary of task dependent objects from reporting settings.

    Any object that has dag_setting defined is considered task dependent,
    even if it has no defined edges.

    Args:
        bq_object_settings: A dict representing BQ [in]dependent objects
            defined in a reporting settings yaml file.

    Returns:
        A dict of task dependent BQ objects keyed by their SQL file.
    """
    reporting_objects = dag_types.ReportingObjects.from_dict(
        bq_object_settings)

    # Populate a map of all BQ objects and the set of files that take part in
    # parent or child dependencies.
    obj_map = {}
    dep_files = set()
    all_objs = []
    if reporting_objects.bq_independent_objects:
        all_objs.extend(reporting_objects.bq_independent_objects)
    if reporting_objects.bq_dependent_objects:
        all_objs.extend(reporting_objects.bq_dependent_objects)
    for obj in all_objs:
        obj_map[obj.sql_file] = obj
        # Objects without dag_setting are ignored and considered independent.
        dag_setting = _get_dag_setting(obj)
        if not dag_setting:
            continue
        dep_files.add(obj.sql_file)
        if dag_setting.parents:
            dep_files = dep_files.union(dag_setting.parents)
    return {k: v for k, v in obj_map.items() if k in dep_files}
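

# Illustrative usage sketch (comments only; `settings` stands in for an
# already-parsed reporting settings yaml dict and is hypothetical):
#
#   task_deps = get_task_deps(settings)
#   # Top level nodes define a dag_setting but list no parents.
#   top_level = {f: o for f, o in task_deps.items()
#                if _get_dag_setting(o) and not _get_dag_setting(o).parents}
#   schedule = get_validated_load_freq(top_level)
#
# get_validated_load_freq then raises if the top level schedules conflict.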


@dataclasses.dataclass
class DependentDagGenerator:
    """Generates DAG files with task dependencies.

    Attributes:
        module_name: Name of the module (e.g. "cm360", "sap").
        target_dataset_name: Name of the BQ dataset,
            e.g. "my_project.my_dataset".
        target_dataset_type: Type of the dataset, e.g. "reporting" or "cdc".
        allow_telemetry: Cortex config file option to enable telemetry.
        location: Location used for BigQueryInsertJob operators.
        output_dir: Directory to write the generated files to.
        jinja_data_file: File containing Jinja substitutions used to create
            the generated SQL.
    """
    module_name: str
    target_dataset_name: str
    target_dataset_type: str
    allow_telemetry: bool
    location: str
    output_dir: Path
    jinja_data_file: Path

    def _generate_sql_file(self, sql_file: str, dag_name: str) -> Path:
        """Returns the relative path to the generated table refresh SQL file."""
        table_name = Path(sql_file).stem
        full_table_name = f"{self.target_dataset_name}.{table_name}"
        bq_client = cortex_bq_client.CortexBQClient()

        # Generate the core SQL text from the SQL file after applying Jinja
        # parameters.
        core_sql = jinja.apply_jinja_params_to_file(sql_file,
                                                    str(self.jinja_data_file))
        generate_assets.validate_sql(bq_client, core_sql)
        refresh_query = generate_assets.generate_table_refresh_sql(
            bq_client, full_table_name, core_sql)

        relative_sql_file = Path(table_name).with_suffix(".sql")
        full_sql_file = self.output_dir / dag_name / relative_sql_file
        full_sql_file.parent.mkdir(exist_ok=True, parents=True)
        full_sql_file.write_text(refresh_query, encoding="utf-8")
        logging.info("Generated DAG SQL file: %s", full_sql_file)
        return relative_sql_file

    def _generate_dag_file(self, dag_name: str, ordered_nodes: list[str],
                           task_dep_objs: dict[str, dag_types.BqObject],
                           load_frequency: str) -> Path:
        logging.info("Generating DAG file for: %s", dag_name)
        parent_dir = Path(__file__).resolve().parent
        template_dir = parent_dir / "templates"
        header_template = (template_dir /
                           "airflow_task_dep_dag_template_reporting.py")
        bq_op_template = template_dir / "bq_insert_job_template.txt"
        dag_full_name = "_".join(
            [self.target_dataset_name.replace(".", "_"), dag_name])
        today = datetime.datetime.now()

        # General template substitutions.
        subs = {
            "dag_full_name": dag_full_name,
            "module_name": self.module_name,
            "tgt_dataset_type": self.target_dataset_type,
            "load_frequency": load_frequency,
            "year": today.year,
            "month": today.month,
            "day": today.day,
            # A placeholder for the label dict string.
            "runtime_labels_dict": "",
            "bq_location": self.location
        }
        if self.allow_telemetry:
            subs["runtime_labels_dict"] = str(constants.CORTEX_JOB_LABEL)
        if self.target_dataset_type == "reporting":
            subs["tags"] = [self.module_name, self.target_dataset_type]

        # Create the DAG header.
        dag_header = dag_generator.generate_str_from_template(
            header_template, **subs)

        # Create BQ ops and edges.
        bq_op_strs = []
        edge_strs = []
        for sql_file in ordered_nodes:
            # Start tasks don't have parent edges or BQ operators.
            if sql_file == "start_task":
                continue
            table_name = Path(sql_file).stem
            if sql_file == "stop_task":
                task_id = sql_file
            else:
                task_id = f"refresh_{table_name}"

            # Generate edges.
            obj = task_dep_objs[sql_file]
            parents = _get_parent_task_ids(
                obj.table_setting.dag_setting.parents)  # type: ignore
            assert parents is not None
            if len(parents) > 1:
                parents_str = ", ".join(sorted(parents))
                edge_str = f"[{parents_str}] >> {task_id}"
            else:
                edge_str = f"{parents[0]} >> {task_id}"
            edge_str = textwrap.indent(edge_str, " " * 4)
            edge_strs.append(edge_str)

            # Don't generate a BQ operator for stop tasks.
            if sql_file == "stop_task":
                continue
            generated_sql_file = self._generate_sql_file(sql_file, dag_name)

            # SQL file specific substitutions.
            subs["table_name"] = table_name
            subs["query_file"] = generated_sql_file

            # Generate BQ operators.
            bq_op = dag_generator.generate_str_from_template(
                bq_op_template, **subs)
            bq_op = textwrap.indent(bq_op, " " * 4)
            bq_op_strs.append(bq_op)

        generated_dag = "\n\n".join([dag_header, *bq_op_strs, *edge_strs])
        generated_dag_file = (self.output_dir / dag_name /
                              dag_full_name).with_suffix(".py")
        generated_dag_file.parent.mkdir(parents=True, exist_ok=True)
        generated_dag_file.write_text(generated_dag)
        logging.info("Generated DAG py file: %s", generated_dag_file)
        return generated_dag_file
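    # Illustrative edge output (hypothetical file names): a node "orders.sql"
    # whose dag_setting lists parents ["customers.sql", "items.sql"] produces
    # the edge string
    #
    #     [refresh_customers, refresh_items] >> refresh_orders
    #
    # while a top level node, parented only by the inferred root, produces
    #
    #     start_task >> refresh_orders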
    def create_dep_dag(self, dag_name: str,
                       task_dep_objs: dict[str, dag_types.BqObject]) -> Path:
        """Creates a single DAG from a map of task dependent objects.

        Args:
            dag_name: Name of the DAG to generate.
            task_dep_objs: Map of SQL file names to their BQ object settings.
                This argument is mutated to add a "start_task" parent to all
                top level nodes and a "stop_task" key with the leaf nodes as
                parents.

        Returns:
            Path to the generated DAG file.

        Raises:
            ValueError: If a cycle is detected in the DAG.
            RuntimeError: If a task dependent object is provided without
                dag_setting defined.
        """
        topo_sorter = graphlib.TopologicalSorter()
        # Top level nodes only have the root as a parent dependency.
        top_level_nodes = {}
        leaf_nodes = set(task_dep_objs.keys())
        for sql_file, obj in task_dep_objs.items():
            dag_setting = _get_dag_setting(obj)
            if not dag_setting:
                raise RuntimeError(
                    "Task dependent object must have dag_setting defined: "
                    f"{obj}")
            # Add an inferred root parent for top level nodes.
            if not dag_setting.parents:
                top_level_nodes[sql_file] = obj
                dag_setting.parents = ["start_task"]
            leaf_nodes = leaf_nodes.difference(dag_setting.parents)
            topo_sorter.add(sql_file, *dag_setting.parents)

        # Add an inferred stop node.
        topo_sorter.add("stop_task", *leaf_nodes)
        task_dep_objs["stop_task"] = dag_types.BqObject(
            type=dag_types.BqObjectType.BQ_OBJECT_TYPE_UNSPECIFIED,
            table_setting=dag_types.Table(dag_setting=dag_types.Dag(
                name=dag_name, parents=list(leaf_nodes))))

        load_freq = get_validated_load_freq(top_level_nodes)
        try:
            ordered_nodes = [*topo_sorter.static_order()]
        except graphlib.CycleError as e:
            raise ValueError(
                f"Cyclic dependency detected in DAG: {dag_name}. "
                "DAG must not have cycles.") from e
        return self._generate_dag_file(dag_name, ordered_nodes, task_dep_objs,
                                       load_freq)
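    # Sketch of the inferred graph (hypothetical files): given "a.sql" with
    # no parents and "b.sql" with parents ["a.sql"], create_dep_dag orders the
    # nodes as
    #
    #     start_task -> a.sql -> b.sql -> stop_task
    #
    # because top level nodes gain "start_task" as a parent and the leaf
    # nodes become parents of the synthetic "stop_task" entry.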
    def create_dep_dags(
            self, task_dep_objs: dict[str, dag_types.BqObject]) -> list[Path]:
        """Creates a series of DAGs from a dict of objects keyed by SQL file.

        Returns:
            List of paths to generated DAG files.
        """
        dags = defaultdict(dict)
        output_files = []

        # Group objects into DAGs by name.
        for sql_file, obj in task_dep_objs.items():
            if not (obj.table_setting and obj.table_setting.dag_setting and
                    obj.table_setting.dag_setting.name):
                raise RuntimeError("Task dependent object was detected but "
                                   f"dag_setting.name is undefined: {obj}")
            dag_name = obj.table_setting.dag_setting.name
            dag_name = dag_name.lower().replace(" ", "_")
            dags[dag_name][sql_file] = obj

        for dag_name, objs in dags.items():
            output_files.append(self.create_dep_dag(dag_name, objs))
        return output_files
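
# End-to-end usage sketch (values are hypothetical placeholders, including
# the jinja data file name and the parsed `settings` dict):
#
#   generator = DependentDagGenerator(
#       module_name="sap",
#       target_dataset_name="my_project.my_dataset",
#       target_dataset_type="reporting",
#       allow_telemetry=False,
#       location="US",
#       output_dir=Path("/tmp/generated_dags"),
#       jinja_data_file=Path("jinja_data.json"))
#   dag_files = generator.create_dep_dags(get_task_deps(settings))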