# source/generate_dag.py
# Copyright 2024 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# [START generate_dag]
from jinja2 import Environment, FileSystemLoader
import os
import argparse
import yaml
import re
from utils import extract_keys_from_py_file, process_condition, has_valid_default_args
config_file = ''
def raise_exception(message):
raise ValueError(message)
def reformat_yaml(yaml_data):
"""Reformats YAML data with an indentation of 4 spaces."""
return yaml.dump(yaml_data, indent=4)
def import_variables(yaml_config: dict) -> dict:
"""
Generate a dictionary for importing variables into a DAG.
This function determines how to import variables based on the
provided YAML configuration. It can import variables from a specified file, or from the YAML configuration. It also checks if 'default_args' is defined and not empty in the variables file.
Args:
yaml_config (dict): YAML configuration for the task.
Returns:
dict: A dictionary containing:
- add_variables (bool): Flag indicating whether to add variables.
- variables (list): A list of variables with preserved indentation.
- variable_keys (list): A list of unique variable keynames from the imported variables
- valid_default_args (bool): True if 'default_args' is defined and not empty, False otherwise.
"""
variables = []
variable_keys = []
valid_default_args = False # Initialize the flag here
try:
task_variables = yaml_config['task_variables']
if not isinstance(task_variables, dict):
raise TypeError("'task_variables' should be a dictionary.")
if task_variables.get('variables_file_path'):
file_path = task_variables['variables_file_path']
variable_keys = extract_keys_from_py_file(file_path)
print(f"Importing variables: reading variables from file {file_path}")
# Check for valid 'default_args' here
valid_default_args = has_valid_default_args(file_path)
with open(file_path, 'r') as file:
file_content = file.readlines() # Read all lines into a list
variables = [line for line in file_content if '# type: ignore' not in line]
except (KeyError, TypeError) as e:
print(f"Importing Variables: Error processing variables file: {e}")
if variables:
return {
"add_variables": True,
"variables": variables,
"variable_keys": variable_keys,
"valid_default_args": valid_default_args
}
else:
print("Importing variables: Skipping variable import.")
return {"add_variables": False}
def import_python_functions(yaml_config: dict) -> dict:
"""
Generate a dictionary for importing Python functions into a DAG.
This function determines how to import Python functions based on the
provided YAML configuration. It can import functions from a specified
file, from the YAML configuration, or from both. It also checks if
'custom_defined_functions' exists and is not empty. It preserves
indentation from the YAML "code" blocks.
Args:
yaml_config (dict): YAML configuration for the task.
Returns:
dict: A dictionary containing:
- add_functions (bool): Flag indicating whether to add functions.
- functions (list): A list of function code strings with preserved indentation.
"""
functions = []
try:
custom_functions = yaml_config['custom_python_functions']
if not isinstance(custom_functions, dict):
raise TypeError("'custom_python_functions' should be a dictionary.")
if custom_functions.get('import_functions_from_file', False) and custom_functions.get('functions_file_path'):
file_path = custom_functions['functions_file_path']
print(f"Importing Python Functions: reading python functions from file {file_path}")
with open(file_path, 'r') as file:
functions.append(file.read())
defined_functions = custom_functions.get('custom_defined_functions')
if defined_functions:
for i, func_data in enumerate(defined_functions.values()):
if 'code' in func_data and func_data['code'].strip():
# Split the code into lines and preserve indentation
code_lines = func_data['code'].splitlines()
functions.append('\n'.join(code_lines)) # Join the lines back
except (KeyError, TypeError) as e:
print(f"Importing Python Functions: Error processing YAML: {e}")
if functions:
return {
"add_functions": True,
"functions": functions
}
else:
print("Importing Python Functions: Skipping function import.")
return {"add_functions": False}
def get_unique_tasks(config_data):
"""
Extracts unique task IDs from a dictionary containing task and task group definitions.
Args:
config_data: A dictionary containing task and task group definitions.
Returns:
A list of unique task IDs, including task group IDs.
"""
all_task_ids = []
# Extract task IDs from regular tasks
for task in config_data.get('tasks', []):
all_task_ids.append(task['task_id'])
# Extract task IDs from task groups and their tasks
for group in config_data.get('task_groups', []):
all_task_ids.append(group['group_id']) # Add the group_id as a task
if 'tasks' in group:
for task in group['tasks']:
all_task_ids.append(task['task_id'])
# Return unique task IDs
return sorted(list(set(all_task_ids)))
def validate_create_task_dependency(yaml_config: dict) -> dict:
"""Validate and create task dependency for the DAG.
Args:
yaml_config (dict): YAML configuration file for the task.
Returns:
dict: task_dependency_type (custom or default) to create task dependency.
Raises:
ValueError: If default_task_dependency is not boolean or
if custom_task_dependency doesn't match config tasks.
"""
try:
default_dependency = yaml_config['task_dependency']['default_task_dependency']
if not isinstance(default_dependency, bool): # More concise type check
raise ValueError("Invalid default_task_dependency value. Acceptable values are True or False")
except KeyError as e:
raise ValueError(f"Missing key in yaml_config: {e}") from e # More informative error message
task_list = get_unique_tasks(yaml_config)
if default_dependency:
return {"task_dependency_type": "default"}
try:
custom_dependency = yaml_config['task_dependency']['custom_task_dependency']
except KeyError as e:
raise ValueError(f"Missing key in yaml_config: {e}") from e
print("Task Validation: validating tasks for custom dependency")
task_dependency = []
custom_tasks = set()
for dependency_chain in custom_dependency:
dependency_chain = dependency_chain.strip('"')
task_dependency.append(dependency_chain)
custom_tasks.update(re.findall(r'[\w_]+', dependency_chain))
if sorted(custom_tasks) != task_list:
print(f"List of config tasks: {task_list}")
print(f"List of custom tasks: {sorted(custom_tasks)}") # Sort for consistent comparison
raise ValueError("Validation error: Tasks in custom_task_dependency don't match config tasks")
print("Task Validation: task validation successful")
return {"task_dependency_type": "custom", "task_dependency": task_dependency}
# Read configuration file from command line
# Please refer to the documentation (README.md) to see how to author a
# configuration (YAML) file that is used by the program to generate
# Airflow DAG python file.
def configure_arg_parser():
description = '''This application creates Composer DAGs based on the config file
config.json and template. Extract Args for Run.'''
parser = argparse.ArgumentParser(description= description)
parser.add_argument('--config_file',
required=True,
help='''Provide template configuration YAML file location
e.g. ./config.yaml''')
parser.add_argument('--dag_template',
default="standard_dag",
help="Template to use for DAG generation")
options = parser.parse_args()
return options
# Generate Airflow DAG python file by reading the config (YAML) file
# that is passed to the program. This section loads a .template file
# located in the ./templates folder in the source and the template folder
# parses and dynamically generate a python file using Jinja2 template
# programming language. Please refer to Jinja documentation for Jinja
# template authoring guidelines.
def generate_dag_file(args):
config_file = args.config_file
dag_template = args.dag_template
with open(config_file,'r') as f:
# Register the tag with the YAML parser
tmp_config_data = yaml.safe_load(f)
config_data = yaml.safe_load(reformat_yaml(tmp_config_data))
config_file_name = os.path.basename(config_file)
config_data["config_file_name"] = config_file_name
config_path = os.path.abspath(config_file)
file_dir = os.path.dirname(os.path.abspath(__file__))
template_dir = os.path.join(file_dir,"templates")
dag_id = config_data['dag_id']
# Reading variables from .py variable file
dag_variables = import_variables(yaml_config=config_data)
# Reading python function from .txt file or from YAML config as per configuration
python_functions = import_python_functions(yaml_config=config_data)
# Importing task_dependency from YAML config as per configuration
task_dependency = validate_create_task_dependency(yaml_config=config_data)
# Importing variables from variables.YAML or from YAML config as per configuration
var_configs = config_data.get("task_variables")
print("Config file: {}".format(config_path))
print("Generating DAG for: {}".format(dag_template))
# Uses template renderer to load and render the Jinja template
# The template file is selected from config_data['dag_template']
# variable from the config file that is input to the program.
env = Environment(
loader=FileSystemLoader(template_dir),
lstrip_blocks=True,
)
# Consolidate functions in env.globals
env.globals.update({
'process_condition': process_condition,
'raise_exception': raise_exception,
})
template = env.get_template(dag_template+".template")
framework_config_values = {'var_configs': var_configs}
dag_path = os.path.abspath(os.path.join(os.path.dirname(config_path), '..', "dags"))
if not os.path.exists(dag_path):
os.makedirs(dag_path)
generate_file_name = os.path.join(dag_path, dag_id + '.py')
with open(generate_file_name, 'w') as fh:
fh.write(
template.render(
config_data=config_data,
framework_config_values=framework_config_values,python_functions=python_functions,
task_dependency=task_dependency,
dag_variables=dag_variables,
)
)
print("Finished generating file: {}".format(generate_file_name))
if __name__ == '__main__':
args = configure_arg_parser()
generate_dag_file(args)
# [END generate_dag]