scripts/in_container/run_migration_reference.py (200 lines of code) (raw):

#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Module to update db migration information in Airflow
"""

from __future__ import annotations

import os
import re
import textwrap
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING

from alembic.script import ScriptDirectory
from rich.console import Console
from tabulate import tabulate

from airflow import __version__ as airflow_version
from airflow.providers.fab import __version__ as fab_version
from airflow.utils.db import _get_alembic_config

if TYPE_CHECKING:
    from alembic.script import Script

console = Console(width=400, color_system="standard")

# Keep only the leading "X.Y.Z" part of each version string (drops any suffix after it).
airflow_version = re.match(r"(\d+\.\d+\.\d+).*", airflow_version).group(1)  # type: ignore
fab_version = re.match(r"(\d+\.\d+\.\d+).*", fab_version).group(1)  # type: ignore
project_root = Path(__file__).parents[2].resolve()


def replace_text_between(file: Path, start: str, end: str, replacement_text: str):
    """
    Replace the text located between two markers of ``file``, keeping the markers.

    :param file: file to rewrite in place
    :param start: marker that opens the replaced region (first occurrence is used)
    :param end: marker that closes the replaced region (first occurrence after ``start``)
    :param replacement_text: text written between the two markers
    :raises ValueError: when either marker cannot be found
    """
    original_text = file.read_text()
    leading_text, found_start, remainder = original_text.partition(start)
    if not found_start:
        # Without this check the whole file would be treated as "leading text"
        # and written back twice.
        raise ValueError(f"Start marker {start!r} not found in {file}")
    # Look for the end marker only after the start marker so everything that
    # follows the region is preserved, even if the end marker occurs again later.
    _, found_end, trailing_text = remainder.partition(end)
    if not found_end:
        raise ValueError(f"End marker {end!r} not found after the start marker in {file}")
    file.write_text(leading_text + start + replacement_text + end + trailing_text)


def wrap_backticks(val):
    """
    Wrap a value in double backticks.

    Tuples and lists are wrapped element by element and joined with ``",\\n"``.
    """

    def _wrap_backticks(x):
        return f"``{x}``"

    return ",\n".join(map(_wrap_backticks, val)) if isinstance(val, (tuple, list)) else _wrap_backticks(val)


def update_doc(file, data, app):
    """
    Regenerate the auto-generated grid table of ``file`` from ``data``.

    :param file: document containing the "auto-generated table" markers
    :param data: rows keyed by ``revision``, ``down_revision``, ``version`` and ``description``
    :param app: application name, title-cased into the "Version" column header
    """
    replace_text_between(
        file=file,
        start=" .. Beginning of auto-generated table\n",
        end=" .. End of auto-generated table\n",
        replacement_text="\n"
        + tabulate(
            headers={
                "revision": "Revision ID",
                "down_revision": "Revises ID",
                "version": f"{app.title()} Version",
                "description": "Description",
            },
            tabular_data=data,
            tablefmt="grid",
            stralign="left",
            disable_numparse=True,
        )
        + "\n\n",
    )


def has_version(content, app):
    """
    Tell whether ``content`` assigns the version variable for ``app`` at the start of a line.

    ``airflow_version`` is looked up for the ``"airflow"`` app, ``fab_version`` for any other.
    """
    variable = "airflow_version" if app == "airflow" else "fab_version"
    return re.search(rf"^{variable}\s*=.*", content, flags=re.MULTILINE) is not None


def insert_version(old_content, file, app):
    """
    Write ``old_content`` to ``file`` with a version assignment added after ``depends_on``.

    The assignment is appended on a new line after every line starting with ``depends_on``.
    It uses ``airflow_version`` for the ``"airflow"`` app and ``fab_version`` for any other.
    """
    if app == "airflow":
        variable, version = "airflow_version", airflow_version
    else:
        variable, version = "fab_version", fab_version
    # A callable replacement keeps the matched text from being interpreted as a template.
    new_content = re.sub(
        r"(^depends_on.*)",
        lambda x: f'{x.group(1)}\n{variable} = "{version}"',
        old_content,
        flags=re.MULTILINE,
    )
    file.write_text(new_content)


def revision_suffix(rev: Script):
    """
    Return the label appended to a revision id in the generated table.

    The first matching property wins, checked in this order: head, base,
    merge point, branch point. An empty string is returned when none applies.
    """
    if rev.is_head:
        return " (head)"
    if rev.is_base:
        return " (base)"
    if rev.is_merge_point:
        return " (merge_point)"
    if rev.is_branch_point:
        return " (branch_point)"
    return ""


def ensure_version(revisions: Iterable[Script], app):
    """Insert the ``app`` version assignment into every migration file that lacks one."""
    for rev in revisions:
        if TYPE_CHECKING:  # For mypy
            assert rev.module.__file__ is not None
        file = Path(rev.module.__file__)
        content = file.read_text()
        if not has_version(content, app=app):
            insert_version(content, file, app=app)


def get_revisions(app="airflow") -> Iterable[Script]:
    """
    Yield the migration scripts of ``app`` in the order alembic walks them.

    ``"airflow"`` uses the Airflow alembic config; any other value uses the FAB DB manager.
    """
    # NOTE(review): walk_revisions presumably yields newest-first (the callers
    # reverse the result to process from the base) -- confirm against alembic docs.
    if app == "airflow":
        config = _get_alembic_config()
        script = ScriptDirectory.from_config(config)
        yield from script.walk_revisions()
    else:
        from airflow.providers.fab.auth_manager.models.db import FABDBManager

        script = FABDBManager(session="").get_script_object()
        yield from script.walk_revisions()


def update_docs(revisions: Iterable[Script], app="airflow"):
    """
    Rebuild the migrations reference document of ``app`` from ``revisions``.

    One table row is produced per revision, in the order given.
    """
    doc_data = []
    for rev in revisions:
        app_revision = rev.module.airflow_version if app == "airflow" else rev.module.fab_version
        doc_data.append(
            {
                "revision": wrap_backticks(rev.revision) + revision_suffix(rev),
                "down_revision": wrap_backticks(rev.down_revision),
                "version": wrap_backticks(app_revision),  # type: ignore
                "description": "\n".join(textwrap.wrap(rev.doc, width=60)),
            }
        )
    if app == "fab":
        filepath = project_root / "providers" / "fab" / "docs" / "migrations-ref.rst"
    else:
        filepath = project_root / "airflow-core" / "docs" / "migrations-ref.rst"

    update_doc(
        file=filepath,
        data=doc_data,
        app=app,
    )


def ensure_mod_prefix(mod_name, idx, version):
    """
    Build the standardized file name ``NNNN_<version parts>_<rest>`` for a migration module.

    :param mod_name: current file name of the migration module
    :param idx: zero-based position of the revision; rendered one-based and zero-padded to 4 digits
    :param version: version tokens placed after the index
    :return: the standardized name; when ``mod_name`` matches neither known layout,
        only the index and version parts are returned
    """
    parts = [f"{idx + 1:04}", *version]
    match = re.match(r"([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)_(.+)", mod_name)
    if match:
        # previously standardized file, rebuild the name
        parts.append(match.group(5))
    else:
        # new migration file, standard format
        match = re.match(r"([a-z0-9]+)_(.+)", mod_name)
        if match:
            parts.append(match.group(2))
    return "_".join(parts)


def ensure_filenames_are_sorted(revisions, app):
    """
    Rename migration files so their numeric prefix follows the order of ``revisions``.

    :param revisions: migration scripts in the order that defines the numbering
    :param app: ``"airflow"`` reads ``airflow_version`` from each module, anything else ``fab_version``
    :raises SystemExit: when the walk ends while still inside an unmerged branch;
        no file is renamed in that case
    """
    renames = []
    is_branched = False
    unmerged_heads = []
    for idx, rev in enumerate(revisions):
        mod_path = Path(rev.module.__file__)
        if app == "airflow":
            version = rev.module.airflow_version.split(".")[0:3]  # only first 3 tokens
        else:
            version = rev.module.fab_version.split(".")[0:3]  # only first 3 tokens
        correct_mod_basename = ensure_mod_prefix(mod_path.name, idx, version)
        if mod_path.name != correct_mod_basename:
            renames.append((mod_path, Path(mod_path.parent, correct_mod_basename)))
        # A merge point closes a previously opened branch ...
        if is_branched and rev.is_merge_point:
            is_branched = False
        # ... and a branch point (re)opens one; heads that are not branch points are recorded.
        if rev.is_branch_point:
            is_branched = True
        elif rev.is_head:
            unmerged_heads.append(rev.revision)

    if is_branched:
        head_prefixes = [x[0:4] for x in unmerged_heads]
        alembic_command = (
            "alembic merge -m 'merge heads " + ", ".join(head_prefixes) + "' " + " ".join(unmerged_heads)
        )
        raise SystemExit(
            "You have multiple alembic heads; please merge them by running `alembic merge` command under "
            f'"airflow" directory (where alembic.ini located) and re-run pre-commit. '
            f"It should fail once more before succeeding.\nhint: `{alembic_command}`"
        )
    for old, new in renames:
        os.rename(old, new)


def _substitute_group(text: str, match: re.Match[str], replacement: str) -> str:
    """Return ``text`` with the span of group 1 of ``match`` replaced by ``replacement``."""
    start, end = match.span(1)
    return text[:start] + replacement + text[end:]


def correct_mismatching_revision_nums(revisions: Iterable[Script]):
    """
    Align the ``Revision ID:`` / ``Revises:`` text of each migration file with its variables.

    The value after ``Revision ID:`` is replaced by the value assigned to ``revision``.
    When a ``down_revision`` assignment with a hex value is present, the value after
    ``Revises:`` is replaced by it. Files are only rewritten when something changed.

    :raises RuntimeError: when an expected ``revision =``, ``Revision ID:`` or ``Revises:``
        entry is missing from a file
    """
    revision_pattern = r'revision = ["\']([a-fA-F0-9]+)["\']'
    down_revision_pattern = r'down_revision = ["\']([a-fA-F0-9]+)["\']'
    revision_id_pattern = r"Revision ID: ([a-fA-F0-9]+)"
    revises_id_pattern = r"Revises: ([a-fA-F0-9]+)"
    for rev in revisions:
        if TYPE_CHECKING:  # For mypy
            assert rev.module.__file__ is not None
        file = Path(rev.module.__file__)
        content = file.read_text()
        revision_match = re.search(
            revision_pattern,
            content,
        )
        if revision_match is None:
            raise RuntimeError(f"revision = not found in {file}")
        revision_id_match = re.search(revision_id_pattern, content)
        if revision_id_match is None:
            raise RuntimeError(f"Revision ID: not found in {file}")
        # Substitute at the matched position: a plain first-occurrence replace could
        # rewrite an unrelated earlier occurrence of the same hex string.
        new_content = _substitute_group(content, revision_id_match, revision_match.group(1))
        # Search again in the updated text so the spans below refer to ``new_content``.
        down_revision_match = re.search(down_revision_pattern, new_content)
        revises_id_match = re.search(revises_id_pattern, new_content)
        if down_revision_match:
            if revises_id_match is None:
                raise RuntimeError(f"Revises: not found in {file}")
            new_content = _substitute_group(new_content, revises_id_match, down_revision_match.group(1))
        if new_content != content:
            file.write_text(new_content)


def main() -> None:
    """Run every migration-reference maintenance step for each supported app."""
    apps = ["airflow", "fab"]
    for app in apps:
        console.print(f"[bright_blue]Updating migration reference for {app}")
        revisions = list(reversed(list(get_revisions(app))))
        console.print(f"[bright_blue]Making sure {app} version updated")
        ensure_version(revisions=revisions, app=app)
        console.print("[bright_blue]Making sure there's no mismatching revision numbers")
        correct_mismatching_revision_nums(revisions=revisions)
        # Reload: the previous steps may have rewritten the migration files.
        revisions = list(reversed(list(get_revisions(app=app))))
        console.print("[bright_blue]Making sure filenames are sorted")
        ensure_filenames_are_sorted(revisions=revisions, app=app)
        # Reload again, this time in walk order, since files may have been renamed.
        revisions = list(get_revisions(app=app))
        console.print("[bright_blue]Updating documentation")
        update_docs(revisions=revisions, app=app)
        console.print("[green]Migrations OK")


if __name__ == "__main__":
    main()