# dags/map_reproducibility/internal_runs/backfill_dags.py

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory function to create Aotc reproducibility backfill DAGs with
adaptive, split-image grouping based on the provided model configurations.
"""

import datetime
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional

from airflow import models
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule

# Assuming these are accessible in your Airflow environment
from dags.map_reproducibility.utils.constants import Image
from dags.map_reproducibility.internal_runs.dag_configs import (
    DAG_CONFIGS_MEGA,
    DAG_CONFIGS_ULTRA,
)
from dags.map_reproducibility.utils.internal_aotc_workload import (
    run_internal_aotc_workload,
)

# --- Default Configuration ---
# These can be overridden when calling the factory function if needed.
DEFAULT_TEST_RUN = False
DEFAULT_BACKFILL = True

# --- Image Tag Generation ---
# A fixed default date used in the image tags; override the tags when
# calling the factory function to pin a different build.
default_utc_date = "2025-04-08"
logging.info(f"Default UTC date for image tags: {default_utc_date}")

DEFAULT_NIGHTLY_IMAGE_TAG = (
    f"{Image.MAXTEXT_JAX_STABLE_NIGHTLY}:{default_utc_date}"
)
DEFAULT_RELEASE_IMAGE_TAG = (
    f"{Image.MAXTEXT_JAX_STABLE_RELEASE}:{default_utc_date}"
)

# --- Base DAG Tags ---
BASE_DAG_TAGS = [
    "reproducibility",
    "experimental",
    "xlml",
    "v1.17",
    "internal",
    "regressiontests",
    "backfill",
]
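
# The factory below expects `model_configs` entries shaped like the real
# DAG_CONFIGS_ULTRA / DAG_CONFIGS_MEGA dictionaries: config YAML path ->
# settings. A minimal illustrative sketch (the path and values here are
# hypothetical, not copied from dag_configs.py):
#
# EXAMPLE_MODEL_CONFIGS = {
#     "recipes/a3ultra/llama3_70b.yaml": {
#         "timeout_minutes": 90,
#         "backfill_group_nightly": 1,  # nightly image runs in group 1
#         "backfill_group_release": 2,  # release image runs in group 2
#     },
# }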


# --- DAG Factory Function ---
def create_adaptive_backfill_dag(
    dag_id: str,
    model_configs: Dict[str, Dict[str, Any]],
    start_date: datetime.datetime,
    dag_tags: Optional[List[str]] = None,
    schedule: Optional[str] = None,  # Default None for backfill
    test_run: bool = DEFAULT_TEST_RUN,
    backfill: bool = DEFAULT_BACKFILL,
    nightly_image_tag: str = DEFAULT_NIGHTLY_IMAGE_TAG,
    release_image_tag: str = DEFAULT_RELEASE_IMAGE_TAG,
    retries: int = 2,  # Number of retries for each task
) -> models.DAG:
    """Creates an Airflow DAG for backfilling Aotc reproducibility benchmarks.

    Features:
    - Sequentially runs tasks in numbered groups.
    - Allows assigning different group numbers for nightly vs. release images
      for the same base configuration via the 'backfill_group_nightly' /
      'backfill_group_release' keys.
    - Dynamically chains groups based on the assigned numbers.
    - Continues execution of subsequent groups even if tasks in previous
      groups fail.

    Args:
        dag_id: Unique identifier for the DAG.
        model_configs: Dictionary where keys are relative config YAML paths
            and values are dictionaries containing 'timeout_minutes',
            'backfill_group_nightly' (int), and 'backfill_group_release'
            (int).
        start_date: The DAG's start date.
        dag_tags: Optional list of tags appended to the base tags.
        schedule: Airflow schedule interval (defaults to None).
        test_run: If True, potentially runs smaller test versions of tasks.
        backfill: Flag indicating whether this run is for backfill purposes.
        nightly_image_tag: Full tag for the nightly image to use.
        release_image_tag: Full tag for the release image to use.
        retries: Number of retries for each task before it is marked failed.

    Returns:
        An Airflow DAG object.
    """
    effective_tags = BASE_DAG_TAGS + (dag_tags or [])

    # Set retries for all tasks via default_args.
    default_args = {
        "retries": retries,
        "retry_delay": datetime.timedelta(minutes=2),
    }

    with models.DAG(
        dag_id=dag_id,
        schedule=schedule,
        tags=effective_tags,
        start_date=start_date,
        catchup=False,  # Important for backfills
        default_args=default_args,
    ) as dag:
        start = EmptyOperator(task_id="start")
        end = EmptyOperator(
            task_id="end",
            # ALL_DONE ensures the end task runs regardless of upstream
            # failures.
            trigger_rule=TriggerRule.ALL_DONE,
        )

        # Organize configs by group number, de-duplicating so the same
        # config+image combination is never added to a group twice.
        group_configs = defaultdict(list)
        processed_configs = set()

        for config_path, config_info in model_configs.items():
            nightly_group = config_info.get("backfill_group_nightly")
            if isinstance(nightly_group, int):
                # Unique key for this config+image+group combination.
                config_key = (config_path, "nightly", nightly_group)
                if config_key not in processed_configs:
                    group_configs[nightly_group].append(
                        (config_path, config_info, "nightly")
                    )
                    processed_configs.add(config_key)

            release_group = config_info.get("backfill_group_release")
            if isinstance(release_group, int):
                config_key = (config_path, "release", release_group)
                if config_key not in processed_configs:
                    group_configs[release_group].append(
                        (config_path, config_info, "release")
                    )
                    processed_configs.add(config_key)

        sorted_group_numbers = sorted(group_configs.keys())

        if not sorted_group_numbers:
            logging.warning(
                f"[{dag_id}] No tasks found in any group. Linking start >> end."
            )
            start >> end
            return dag

        logging.info(
            f"[{dag_id}] Found configurations for groups: {sorted_group_numbers}"
        )

        # Create a TaskGroup for each group number.
        task_groups = {}
        group_gateways = {}  # Start/end gateway tasks for each group

        for group_num in sorted_group_numbers:
            group_task_id = f"group_{group_num}"

            with TaskGroup(group_id=group_task_id) as tg:
                # Create the start and end gateways for this group.
                group_start = EmptyOperator(
                    task_id=f"group_{group_num}_start",
                    trigger_rule=TriggerRule.ALL_SUCCESS,  # Default behavior
                )
                group_end = EmptyOperator(
                    task_id=f"group_{group_num}_end",
                    # ALL_DONE lets the group complete, and the next group
                    # start, even if some tasks in this group fail.
                    trigger_rule=TriggerRule.ALL_DONE,
                )
                group_gateways[group_num] = (group_start, group_end)

                # Create the workload tasks within this group.
                task_list = []
                for config_path, config_info, image_type in sorted(
                    group_configs[group_num],
                    key=lambda x: (x[0], x[2]),  # config path, then image type
                ):
                    config_name = os.path.basename(config_path).replace(
                        ".yaml", ""
                    )
                    timeout = config_info.get("timeout_minutes", 60)

                    # Pick the image tag that matches the image type.
                    image_tag = (
                        nightly_image_tag
                        if image_type == "nightly"
                        else release_image_tag
                    )

                    custom_task_id = f"{config_name}_{image_type}"
                    task = run_internal_aotc_workload.override(
                        task_id=custom_task_id
                    )(
                        relative_config_yaml_path=config_path,
                        test_run=test_run,
                        backfill=backfill,
                        timeout=timeout,
                        image_version=image_tag,
                    )
                    task_list.append(task)

                # Fan the tasks out between the group gateways.
                if task_list:
                    for task in task_list:
                        group_start >> task >> group_end
                else:
                    # No tasks in this group; just link its gateways.
                    group_start >> group_end

            task_groups[group_num] = tg

        # Wire the DAG start to the first group's start gateway.
        start >> group_gateways[sorted_group_numbers[0]][0]

        # Chain the groups in sequence: each group's end gateway feeds the
        # next group's start gateway, so later groups still run after
        # failures in earlier ones.
        for i in range(len(sorted_group_numbers) - 1):
            current_group = sorted_group_numbers[i]
            next_group = sorted_group_numbers[i + 1]
            group_gateways[current_group][1] >> group_gateways[next_group][0]

        # Connect the last group's end to the DAG end.
        group_gateways[sorted_group_numbers[-1]][1] >> end

    return dag
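
# Resulting graph shape (illustrative) for a config set that fills
# groups 1 and 2:
#
#   start >> group_1.group_1_start >> [workloads] >> group_1.group_1_end
#         >> group_2.group_2_start >> [workloads] >> group_2.group_2_end
#         >> end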

# --- Instantiate specific DAGs using the factory ---
# BASE_DAG_TAGS is already applied inside the factory, so no extra
# dag_tags are passed here.

# a3ultra backfill DAG
dag1 = create_adaptive_backfill_dag(
    dag_id="new_internal_backfill_a3ultra",
    model_configs=DAG_CONFIGS_ULTRA,
    start_date=datetime.datetime(2025, 4, 11),
)

# a3mega backfill DAG
dag2 = create_adaptive_backfill_dag(
    dag_id="new_internal_backfill_a3mega",
    model_configs=DAG_CONFIGS_MEGA,
    start_date=datetime.datetime(2025, 4, 11),
)
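
# Example of pinning a specific image date for a one-off backfill
# (a sketch; the dag_id and date below are hypothetical):
#
# dag3 = create_adaptive_backfill_dag(
#     dag_id="new_internal_backfill_a3ultra_2025_04_01",
#     model_configs=DAG_CONFIGS_ULTRA,
#     start_date=datetime.datetime(2025, 4, 11),
#     nightly_image_tag=f"{Image.MAXTEXT_JAX_STABLE_NIGHTLY}:2025-04-01",
#     release_image_tag=f"{Image.MAXTEXT_JAX_STABLE_RELEASE}:2025-04-01",
# )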