migrations/env.py (104 lines of code) (raw):

"""Alembic environment script for the ATR database.

Configures Alembic to use the SQLModel metadata populated by the ATR models,
points it at the SQLite database named in the ATR application configuration,
and installs two hooks:

* ``process_revision_directives_custom_naming`` gives new revisions IDs of
  the form ``NNNN_YYYY.MM.DD_COMMITSHORT``.
* ``render_item_override`` renders SQLModel ``AutoString`` columns as
  ``sa.String()`` in autogenerated migrations.

Importing this module runs the migrations, in offline or online mode
depending on ``alembic.context.is_offline_mode()``.
"""

import datetime
import logging.config
import os
import re
import subprocess
import sys
from collections.abc import Iterable
from typing import Literal

import alembic
import alembic.autogenerate.api as api
import alembic.operations.ops as ops
import alembic.runtime.migration as migration
import sqlalchemy
import sqlmodel
import sqlmodel.sql.sqltypes as sqltypes

# Add the project root to the Python path
# This script must be at migrations/env.py for this to work
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Use database metadata from ATR directly
# These imports must stay below the sys.path adjustment above
import atr.config

# Populate SQLModel.metadata as a side effect of importing the models
import atr.db.models

# This is the Alembic Config object, which provides
# access to the values within the .ini file in use.
alembic_config = alembic.context.config

# Interpret the config file for Python logging, if one was supplied.
if alembic_config.config_file_name is not None:
    logging.config.fileConfig(alembic_config.config_file_name)

# The SQLModel.metadata object as populated by the ATR models
target_metadata = sqlmodel.SQLModel.metadata

# Get the database path from application configuration
app_config = atr.config.get()
absolute_db_path = os.path.join(app_config.STATE_DIR, app_config.SQLITE_DB_PATH)

# Construct the synchronous SQLite URL using the absolute path
# Three slashes come before any absolute or relative path
sync_sqlalchemy_url = f"sqlite:///{absolute_db_path}"

# Matches migration filenames that start with a four digit sequence number
_REVISION_FILENAME_PATTERN = re.compile(r"^(\d{4})_.*\.py$")


def get_short_commit_hash(project_root_path: str) -> str:
    """Get an eight character git commit hash, or a fallback.

    Args:
        project_root_path: Directory in which to run ``git rev-parse``.

    Returns:
        The abbreviated hash of ``HEAD``, or ``"00000000"`` when git exits
        with a non-zero status or cannot be started at all.
    """
    try:
        process = subprocess.run(
            ["git", "rev-parse", "--short=8", "HEAD"],
            capture_output=True,
            text=True,
            cwd=project_root_path,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing git binary (FileNotFoundError) as well as
        # other launch failures such as PermissionError or a bad cwd.
        return "00000000"
    return process.stdout.strip()


def process_revision_directives_custom_naming(
    context: migration.MigrationContext,
    revision: str | Iterable[str | None] | Iterable[str],
    directives: list[ops.MigrationScript],
) -> None:
    """Generate revision IDs and filenames like NNNN_YYYY.MM.DD_COMMITSHORT.py.

    NNNN is one more than the highest four digit prefix found among the
    files in the ``versions`` directory, the date is today's date, and
    COMMITSHORT comes from ``get_short_commit_hash``.

    Args:
        context: The active migration context; its ``script`` attribute
            supplies the script directory.
        revision: Unused; accepted because Alembic passes it to the hook.
        directives: Migration scripts to rename. Each one has its ``rev_id``
            and ``path`` attributes overwritten in place.

    Raises:
        RuntimeError: If ``context.script`` is ``None``.
    """
    if context.script is None:
        raise RuntimeError("MigrationContext.script is None, cannot determine script directory")

    versions_path = os.path.join(context.script.dir, "versions")
    # exist_ok avoids the race between checking for and creating the directory
    os.makedirs(versions_path, exist_ok=True)

    try:
        filenames = os.listdir(versions_path)
    except OSError as e:
        # Best effort: fall back to numbering from 0001 if the scan fails
        print(f"Warning: Error scanning versions directory '{versions_path}': {e!r}")
        filenames = []

    existing_numbers = [
        int(match.group(1))
        for fname in filenames
        if (match := _REVISION_FILENAME_PATTERN.match(fname))
    ]
    highest_num = max(existing_numbers, default=0)

    next_num_str = f"{highest_num + 1:04d}"
    date_str = datetime.date.today().strftime("%Y.%m.%d")
    commit_short = get_short_commit_hash(project_root)

    new_rev_id = f"{next_num_str}_{date_str}_{commit_short}"
    calculated_path = os.path.join(versions_path, f"{new_rev_id}.py")

    for directive in directives:
        # NOTE(review): setattr rather than plain assignment is kept from the
        # original; presumably it avoids type checker complaints - confirm.
        setattr(directive, "rev_id", new_rev_id)
        setattr(directive, "path", calculated_path)


def render_item_override(type_: str, item: object, autogen_context: api.AutogenContext) -> str | Literal[False]:
    """Apply custom rendering for SQLModel AutoString.

    Prevents autogenerate from rendering <AutoString>.

    Args:
        type_: Kind of item being rendered; only ``"type"`` is handled here.
        item: The object being rendered.
        autogen_context: Autogenerate context whose ``imports`` set is
            extended with ``import sqlalchemy as sa``.

    Returns:
        ``"sa.String()"`` for an ``AutoString`` type, otherwise ``False`` to
        indicate no handler for other types.
    """
    # Add import for sqlalchemy as sa if not present
    autogen_context.imports.add("import sqlalchemy as sa")

    if (type_ == "type") and isinstance(item, sqltypes.AutoString):
        # Render sqlmodel.sql.sqltypes.AutoString as sa.String()
        return "sa.String()"

    # Default rendering for other types
    return False


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the context with the database URL only, without creating an
    engine or connection, and renders bound parameters as literals.
    """
    alembic.context.configure(
        url=sync_sqlalchemy_url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        render_item=render_item_override,
        process_revision_directives=process_revision_directives_custom_naming,
    )

    with alembic.context.begin_transaction():
        alembic.context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an engine from the Alembic config section, with the URL replaced
    by the one derived from the ATR configuration, and runs the migrations
    over a connection from that engine.
    """
    configuration = alembic_config.get_section(alembic_config.config_ini_section) or {}
    # The URL from the ATR configuration takes precedence over the .ini file
    configuration["sqlalchemy.url"] = sync_sqlalchemy_url

    connectable = sqlalchemy.engine_from_config(
        configuration,
        prefix="sqlalchemy.",
        poolclass=sqlalchemy.pool.NullPool,
    )

    with connectable.connect() as connection:
        alembic.context.configure(
            connection=connection,
            target_metadata=target_metadata,
            render_item=render_item_override,
            process_revision_directives=process_revision_directives_custom_naming,
        )

        with alembic.context.begin_transaction():
            alembic.context.run_migrations()


if alembic.context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()