# tools/cloud-composer-dag-validation/dag_validation.py (124 lines of code) (raw):
# Copyright 2024 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import unittest
import re
import ast
from datetime import timedelta
from airflow.models import DagBag
from airflow.utils.dag_cycle_tester import check_cycle
# Module-wide logging: attach the default handler with a timestamped format
# to the root logger, then let INFO-level messages through.
logging.basicConfig(format="%(asctime)s %(message)s")
logger = logging.getLogger()
logger.setLevel(logging.INFO)
def has_top_level_code(file_path):
    """Report whether a Python file is free of top-level function definitions.

    Despite its name, this returns ``True`` when *no* ``def`` or
    ``async def`` statement appears directly in the module body, and
    ``False`` when at least one does.  Only the module body is inspected;
    functions nested inside classes or other functions are ignored.

    Args:
        file_path: Path of the Python source file to inspect.

    Returns:
        ``False`` if the module body contains a function definition,
        ``True`` otherwise.  A file that cannot be parsed is also reported
        as ``True`` after "Syntax Error" is logged.
    """
    # Read raw bytes so ast.parse can honour the file's own encoding
    # declaration instead of depending on the locale's default encoding.
    with open(file_path, "rb") as file:
        source = file.read()

    # Keep the try body down to the single call that can raise SyntaxError.
    try:
        parsed_code = ast.parse(source)
    except SyntaxError:
        # Syntax error in the file, it doesn't have top-level code
        logger.info("Syntax Error")
        return True

    return not any(
        isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        for node in parsed_code.body
    )
class TestDagIntegrity(unittest.TestCase):
    """
    Class that holds all DAG Integrity tests.

    Every test runs against a fresh ``DagBag`` built in ``setUp`` from the
    directory named by the ``INPUT_DAGPATHS`` environment variable
    (``dags/`` when the variable is unset).  Checks use ``unittest``
    assertion methods rather than bare ``assert`` statements so that they
    are still enforced when Python runs with ``-O``.
    """

    # Upper bound, in seconds, for the summed DagBag load duration.
    LOAD_SECOND_THRESHOLD = 2
    # Inclusive bounds accepted for the ``retries`` entry of default_args.
    MIN_RETRY = 1
    MAX_RETRY = 4

    def setUp(self):
        """Load the DAGs under test into ``self.dagbag``."""
        dags_dir = os.getenv("INPUT_DAGPATHS", default="dags/")
        logger.info("DAGs dir : %s", dags_dir)
        self.dagbag = DagBag(dag_folder=dags_dir, include_examples=False)

    def test_no_import_errors(self):
        """Check to see if a DAG has import errors."""
        # assertFalse reports the offending mapping in the failure output.
        self.assertFalse(self.dagbag.import_errors, "DAG Import Errors.")

    def test_dag_loads_within_threshold(self):
        """
        Check to see if a collection (bag) of DAGs will load faster than
        the specified threshold.
        """
        duration = sum(
            (o.duration for o in self.dagbag.dagbag_stats), timedelta()
        ).total_seconds()
        logger.info("Duration = %s", duration)
        self.assertLessEqual(
            duration,
            self.LOAD_SECOND_THRESHOLD,
            "DAG Bag load time is above the given threshold.",
        )

    def test_dag_task_cycle(self):
        """Check to see if a task cycle exists in a DAG.

        Also fails when the DagBag contains no DAG at all.
        """
        if not self.dagbag.dags:
            self.fail("Module does not contain a valid DAG")
        for dag in self.dagbag.dags.values():
            check_cycle(dag)  # Throws if a task cycle is found.

    def test_dag_toplevelcode(self):
        """Check each DAG's ``.py`` source file with has_top_level_code.

        A DAG fails when has_top_level_code returns a falsy value for the
        file it was loaded from; DAGs whose location does not end in
        ``.py`` are skipped.
        """
        for dag_id, dag in self.dagbag.dags.items():
            fileloc = dag.fileloc
            if fileloc.endswith(".py"):
                error_msg = f"DAG {dag_id}: Top-level code exists."
                self.assertTrue(has_top_level_code(fileloc), error_msg)

    def test_task_count(self):
        """Check task count for a DAG."""
        for dag_id, dag in self.dagbag.dags.items():
            error_msg = f"DAG {dag_id}: doesn't have any tasks."
            self.assertGreater(len(dag.tasks), 0, error_msg)

    def test_valid_schedule_interval(self):
        """Check to see if a DAG has a valid schedule interval.

        DAGs with a falsy ``schedule_interval`` are skipped.
        """
        # Raw strings avoid invalid-escape warnings; the concatenated
        # pattern is identical to the original single-line expression.
        # NOTE: ``match`` anchors only at the start of the string.
        valid_cron_expressions = re.compile(
            r"(@(annually|yearly|monthly|weekly|daily|hourly|reboot))"
            r"|(@every (\d+(ns|us|µs|ms|s|m|h))+)"
            r"|((((\d+,)+\d+|([\d\*]+(\/|-)\d+)|\d+|\*) ?){5,7})"
        )
        for dag_id, dag in self.dagbag.dags.items():
            schedule = dag.schedule_interval
            if schedule:
                valid = valid_cron_expressions.match(str(schedule))
                error_msg = (
                    f"DAG {dag_id} has invalid cron expression or no schedule."
                )
                self.assertTrue(valid, error_msg)

    def test_owner_present(self):
        """Check to see if a DAG has an owner set in the default arguments."""
        for dag_id, dag in self.dagbag.dags.items():
            owner = dag.default_args.get("owner")
            error_msg = f"DAG {dag_id}: owner not set in default_args."
            self.assertTrue(owner, error_msg)

    def test_sla_present(self):
        """Check to see if a DAG has an SLA set in the default arguments."""
        for dag_id, dag in self.dagbag.dags.items():
            sla = dag.default_args.get("sla")
            error_msg = f"DAG {dag_id}: sla not set in default_args."
            self.assertTrue(sla, error_msg)

    def test_sla_less_than_timeout(self):
        """Check to see if a DAG has an SLA less than its dagrun_timeout.

        A missing ``sla`` or ``dagrun_timeout`` is reported as a test
        failure instead of surfacing as a ``TypeError`` from the comparison.
        """
        for dag_id, dag in self.dagbag.dags.items():
            sla = dag.default_args.get("sla")
            dagrun_timeout = dag.dagrun_timeout
            self.assertIsNotNone(
                sla, f"DAG {dag_id}: sla not set in default_args."
            )
            self.assertIsNotNone(
                dagrun_timeout, f"DAG {dag_id}: dagrun_timeout not set."
            )
            error_msg = f"DAG {dag_id}: sla is greater than dagrun_timeout."
            self.assertGreater(dagrun_timeout, sla, error_msg)

    def test_retries_present(self):
        """Check to see if a DAG has retries set within a given range."""
        for dag_id, dag in self.dagbag.dags.items():
            # Default to None (not a list) so that a missing value fails
            # the range check rather than raising TypeError.
            retries = dag.default_args.get("retries")
            error_msg = f"DAG {dag_id}: retries not set within specified range {self.MIN_RETRY}-{self.MAX_RETRY}."  # noqa
            in_range = (
                retries is not None
                and self.MIN_RETRY <= retries <= self.MAX_RETRY
            )
            self.assertTrue(in_range, error_msg)

    def test_retry_delay_present(self):
        """Check to see if a DAG has a retry delay."""
        for dag_id, dag in self.dagbag.dags.items():
            retry_delay = dag.default_args.get("retry_delay")
            error_msg = f"DAG {dag_id}: retry delay not set."
            self.assertTrue(retry_delay, error_msg)

    def test_catchup_false(self):
        """Check to see if a DAG has catchup set to false."""
        for dag_id, dag in self.dagbag.dags.items():
            error_msg = f"DAG {dag_id}: catchup not set to False."
            self.assertFalse(dag.catchup, error_msg)

    def test_dag_timeout_set(self):
        """Check to see if a DAG has a timeout set."""
        for dag_id, dag in self.dagbag.dags.items():
            error_msg = f"DAG {dag_id}: dagrun_timeout not set."
            self.assertTrue(dag.dagrun_timeout, error_msg)

    def test_dag_description_set(self):
        """Check to see if a DAG has a description set."""
        for dag_id, dag in self.dagbag.dags.items():
            error_msg = f"DAG {dag_id}: description not set."
            self.assertTrue(dag.description, error_msg)

    def test_dag_paused_true(self):
        """Check to see if a DAG is paused on creation."""
        for dag_id, dag in self.dagbag.dags.items():
            error_msg = f"DAG {dag_id}: paused not set to True."
            self.assertTrue(dag.is_paused_upon_creation, error_msg)

    def test_dag_has_tags(self):
        """Test if a DAG is tagged."""
        for dag_id, dag in self.dagbag.dags.items():
            error_msg = f"DAG {dag_id}: no tags exist."
            self.assertGreater(len(dag.tags), 0, error_msg)
# Script entry point: hand control to the unittest runner, which discovers
# and executes the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()