atr/db/__init__.py (555 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import annotations import dataclasses import logging import os import re from typing import TYPE_CHECKING, Any, Final, Generic, TypeGuard, TypeVar import sqlalchemy import sqlalchemy.dialects.sqlite import sqlalchemy.ext.asyncio import sqlalchemy.orm as orm import sqlalchemy.sql as sql import sqlmodel import sqlmodel.sql.expression as expression from alembic import command from alembic.config import Config import atr.config as config import atr.db.models as models import atr.user as user import atr.util as util from atr import analysis if TYPE_CHECKING: import datetime import pathlib from collections.abc import Sequence import asfquart.base as base _LOGGER: Final = logging.getLogger(__name__) _global_atr_engine: sqlalchemy.ext.asyncio.AsyncEngine | None = None _global_atr_sessionmaker: sqlalchemy.ext.asyncio.async_sessionmaker | None = None T = TypeVar("T") class NotSet: """ A marker class to indicate that a value is not set and thus should not be considered. This is different to None. """ _instance = None def __new__(cls): # type: ignore if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __repr__(self) -> str: return "<NotSet>" def __copy__(self): # type: ignore return NotSet def __deepcopy__(self, memo: dict[int, Any]): # type: ignore return NotSet NOT_SET: Final[NotSet] = NotSet() type Opt[T] = T | NotSet @dataclasses.dataclass class PathInfo: artifacts: set[pathlib.Path] = dataclasses.field(default_factory=set) errors: dict[pathlib.Path, list[models.CheckResult]] = dataclasses.field(default_factory=dict) metadata: set[pathlib.Path] = dataclasses.field(default_factory=set) # substitutions: dict[pathlib.Path, str] = dataclasses.field(default_factory=dict) successes: dict[pathlib.Path, list[models.CheckResult]] = dataclasses.field(default_factory=dict) # templates: dict[pathlib.Path, str] = dataclasses.field(default_factory=dict) warnings: dict[pathlib.Path, list[models.CheckResult]] = dataclasses.field(default_factory=dict) class Query(Generic[T]): def __init__(self, session: Session, query: expression.SelectOfScalar[T]): self.query = query self.session = session def order_by(self, *args: Any, **kwargs: Any) -> Query[T]: self.query = self.query.order_by(*args, **kwargs) return self async def get(self) -> T | None: result = await self.session.execute(self.query) return result.scalar_one_or_none() async def demand(self, error: Exception) -> T: result = await self.session.execute(self.query) item = result.scalar_one_or_none() if item is None: raise error return item async def all(self) -> Sequence[T]: result = await self.session.execute(self.query) return result.scalars().all() # async def execute(self) -> sqlalchemy.Result[tuple[T]]: # return await self.session.execute(self.query) class Session(sqlalchemy.ext.asyncio.AsyncSession): # TODO: Need to type all of these arguments correctly def check_result( self, id: Opt[int] = NOT_SET, release_name: Opt[str] = NOT_SET, checker: Opt[str] = NOT_SET, primary_rel_path: Opt[str | None] = NOT_SET, created: Opt[datetime.datetime] = NOT_SET, status: Opt[models.CheckResultStatus] = NOT_SET, message: Opt[str] = NOT_SET, data: Opt[Any] = NOT_SET, _release: bool = False, ) -> Query[models.CheckResult]: query = sqlmodel.select(models.CheckResult) if is_defined(id): query = query.where(models.CheckResult.id == id) if is_defined(release_name): query = query.where(models.CheckResult.release_name == release_name) if is_defined(checker): query = query.where(models.CheckResult.checker == checker) if is_defined(primary_rel_path): query = query.where(models.CheckResult.primary_rel_path == primary_rel_path) if is_defined(created): query = query.where(models.CheckResult.created == created) if is_defined(status): query = query.where(models.CheckResult.status == status) if is_defined(message): query = query.where(models.CheckResult.message == message) if is_defined(data): query = query.where(models.CheckResult.data == data) if _release: query = query.options(select_in_load(models.CheckResult.release)) return Query(self, query) def committee( self, name: Opt[str] = NOT_SET, full_name: Opt[str] = NOT_SET, is_podling: Opt[bool] = NOT_SET, parent_committee_name: Opt[str] = NOT_SET, committee_members: Opt[list[str]] = NOT_SET, committers: Opt[list[str]] = NOT_SET, release_managers: Opt[list[str]] = NOT_SET, name_in: Opt[list[str]] = NOT_SET, _projects: bool = False, _public_signing_keys: bool = False, ) -> Query[models.Committee]: query = sqlmodel.select(models.Committee) if is_defined(name): query = query.where(models.Committee.name == name) if is_defined(full_name): query = query.where(models.Committee.full_name == full_name) if is_defined(is_podling): query = query.where(models.Committee.is_podling == is_podling) if is_defined(parent_committee_name): query = query.where(models.Committee.parent_committee_name == parent_committee_name) if is_defined(committee_members): query = query.where(models.Committee.committee_members == committee_members) if is_defined(committers): query = query.where(models.Committee.committers == committers) if is_defined(release_managers): query = query.where(models.Committee.release_managers == release_managers) if is_defined(name_in): models_committee_name = validate_instrumented_attribute(models.Committee.name) query = query.where(models_committee_name.in_(name_in)) if _projects: query = query.options(select_in_load(models.Committee.projects)) if _public_signing_keys: query = query.options(select_in_load(models.Committee.public_signing_keys)) return Query(self, query) async def ns_text_del(self, ns: str, key: str, commit: bool = True) -> None: stmt = sql.delete(models.TextValue).where( validate_instrumented_attribute(models.TextValue.ns) == ns, validate_instrumented_attribute(models.TextValue.key) == key, ) await self.execute(stmt) if commit is True: await self.commit() async def ns_text_del_all(self, ns: str, commit: bool = True) -> None: stmt = sql.delete(models.TextValue).where( validate_instrumented_attribute(models.TextValue.ns) == ns, ) await self.execute(stmt) if commit is True: await self.commit() async def ns_text_get(self, ns: str, key: str) -> str | None: stmt = sql.select(models.TextValue).where( validate_instrumented_attribute(models.TextValue.ns) == ns, validate_instrumented_attribute(models.TextValue.key) == key, ) result = await self.execute(stmt) match result.scalar_one_or_none(): case models.TextValue(value=value): return value case None: return None async def ns_text_set(self, ns: str, key: str, value: str, commit: bool = True) -> None: # Don't use sql.insert(), it won't give on_conflict_do_update() stmt = sqlalchemy.dialects.sqlite.insert(models.TextValue).values((ns, key, value)) stmt = stmt.on_conflict_do_update( index_elements=[models.TextValue.ns, models.TextValue.key], set_=dict(value=value) ) await self.execute(stmt) if commit is True: await self.commit() def project( self, name: Opt[str] = NOT_SET, full_name: Opt[str] = NOT_SET, is_podling: Opt[bool] = NOT_SET, committee_name: Opt[str] = NOT_SET, release_policy_id: Opt[int] = NOT_SET, _committee: bool = False, _releases: bool = False, _distribution_channels: bool = False, _super_project: bool = False, _release_policy: bool = False, _committee_public_signing_keys: bool = False, ) -> Query[models.Project]: query = sqlmodel.select(models.Project) if is_defined(name): query = query.where(models.Project.name == name) if is_defined(full_name): query = query.where(models.Project.full_name == full_name) if is_defined(is_podling): query = query.where(models.Project.is_podling == is_podling) if is_defined(committee_name): query = query.where(models.Project.committee_name == committee_name) if is_defined(release_policy_id): query = query.where(models.Project.release_policy_id == release_policy_id) if _committee: query = query.options(select_in_load(models.Project.committee)) if _releases: query = query.options(select_in_load(models.Project.releases)) if _distribution_channels: query = query.options(select_in_load(models.Project.distribution_channels)) if _super_project: query = query.options(select_in_load(models.Project.super_project)) if _release_policy: query = query.options(select_in_load(models.Project.release_policy)) if _committee_public_signing_keys: query = query.options(select_in_load_nested(models.Project.committee, models.Committee.public_signing_keys)) return Query(self, query) def public_signing_key( self, fingerprint: Opt[str] = NOT_SET, algorithm: Opt[str] = NOT_SET, length: Opt[int] = NOT_SET, created: Opt[datetime.datetime] = NOT_SET, expires: Opt[datetime.datetime | None] = NOT_SET, declared_uid: Opt[str | None] = NOT_SET, apache_uid: Opt[str] = NOT_SET, ascii_armored_key: Opt[str] = NOT_SET, _committees: bool = False, ) -> Query[models.PublicSigningKey]: query = sqlmodel.select(models.PublicSigningKey) if is_defined(fingerprint): query = query.where(models.PublicSigningKey.fingerprint == fingerprint) if is_defined(algorithm): query = query.where(models.PublicSigningKey.algorithm == algorithm) if is_defined(length): query = query.where(models.PublicSigningKey.length == length) if is_defined(created): query = query.where(models.PublicSigningKey.created == created) if is_defined(expires): query = query.where(models.PublicSigningKey.expires == expires) if is_defined(declared_uid): query = query.where(models.PublicSigningKey.declared_uid == declared_uid) if is_defined(apache_uid): query = query.where(models.PublicSigningKey.apache_uid == apache_uid) if is_defined(ascii_armored_key): query = query.where(models.PublicSigningKey.ascii_armored_key == ascii_armored_key) if _committees: query = query.options(select_in_load(models.PublicSigningKey.committees)) return Query(self, query) def release( self, name: Opt[str] = NOT_SET, stage: Opt[models.ReleaseStage] = NOT_SET, phase: Opt[models.ReleasePhase] = NOT_SET, created: Opt[datetime.datetime] = NOT_SET, project_name: Opt[str] = NOT_SET, package_managers: Opt[list[str]] = NOT_SET, version: Opt[str] = NOT_SET, revision: Opt[str] = NOT_SET, sboms: Opt[list[str]] = NOT_SET, release_policy_id: Opt[int] = NOT_SET, votes: Opt[list[models.VoteEntry]] = NOT_SET, _project: bool = False, _release_policy: bool = False, _committee: bool = False, _tasks: bool = False, ) -> Query[models.Release]: query = sqlmodel.select(models.Release) if is_defined(name): query = query.where(models.Release.name == name) if is_defined(stage): query = query.where(models.Release.stage == stage) if is_defined(phase): query = query.where(models.Release.phase == phase) if is_defined(created): query = query.where(models.Release.created == created) if is_defined(project_name): query = query.where(models.Release.project_name == project_name) if is_defined(package_managers): query = query.where(models.Release.package_managers == package_managers) if is_defined(version): query = query.where(models.Release.version == version) if is_defined(revision): query = query.where(models.Release.revision == revision) if is_defined(sboms): query = query.where(models.Release.sboms == sboms) if is_defined(release_policy_id): query = query.where(models.Release.release_policy_id == release_policy_id) if is_defined(votes): query = query.where(models.Release.votes == votes) if _project: query = query.options(select_in_load(models.Release.project)) if _release_policy: query = query.options(select_in_load(models.Release.release_policy)) if _committee: query = query.options(select_in_load_nested(models.Release.project, models.Project.committee)) if _tasks: query = query.options(select_in_load(models.Release.tasks)) return Query(self, query) def release_policy( self, id: Opt[int] = NOT_SET, mailto_addresses: Opt[list[str]] = NOT_SET, manual_vote: Opt[bool] = NOT_SET, min_hours: Opt[int] = NOT_SET, release_checklist: Opt[str] = NOT_SET, pause_for_rm: Opt[bool] = NOT_SET, _project: bool = False, ) -> Query[models.ReleasePolicy]: query = sqlmodel.select(models.ReleasePolicy) if is_defined(id): query = query.where(models.ReleasePolicy.id == id) if is_defined(mailto_addresses): query = query.where(models.ReleasePolicy.mailto_addresses == mailto_addresses) if is_defined(manual_vote): query = query.where(models.ReleasePolicy.manual_vote == manual_vote) if is_defined(min_hours): query = query.where(models.ReleasePolicy.min_hours == min_hours) if is_defined(release_checklist): query = query.where(models.ReleasePolicy.release_checklist == release_checklist) if is_defined(pause_for_rm): query = query.where(models.ReleasePolicy.pause_for_rm == pause_for_rm) if _project: query = query.options(select_in_load(models.ReleasePolicy.project)) return Query(self, query) def ssh_key( self, fingerprint: Opt[str] = NOT_SET, key: Opt[str] = NOT_SET, asf_uid: Opt[str] = NOT_SET, ) -> Query[models.SSHKey]: query = sqlmodel.select(models.SSHKey) if is_defined(fingerprint): query = query.where(models.SSHKey.fingerprint == fingerprint) if is_defined(key): query = query.where(models.SSHKey.key == key) if is_defined(asf_uid): query = query.where(models.SSHKey.asf_uid == asf_uid) return Query(self, query) def task( self, id: Opt[int] = NOT_SET, status: Opt[models.TaskStatus] = NOT_SET, task_type: Opt[str] = NOT_SET, task_args: Opt[Any] = NOT_SET, added: Opt[datetime.datetime] = NOT_SET, started: Opt[datetime.datetime | None] = NOT_SET, pid: Opt[int | None] = NOT_SET, completed: Opt[datetime.datetime | None] = NOT_SET, result: Opt[Any | None] = NOT_SET, error: Opt[str | None] = NOT_SET, release_name: Opt[str | None] = NOT_SET, _release: bool = False, ) -> Query[models.Task]: query = sqlmodel.select(models.Task) if is_defined(id): query = query.where(models.Task.id == id) if is_defined(status): query = query.where(models.Task.status == status) if is_defined(task_type): query = query.where(models.Task.task_type == task_type) if is_defined(task_args): query = query.where(models.Task.task_args == task_args) if is_defined(added): query = query.where(models.Task.added == added) if is_defined(started): query = query.where(models.Task.started == started) if is_defined(pid): query = query.where(models.Task.pid == pid) if is_defined(completed): query = query.where(models.Task.completed == completed) if is_defined(result): query = query.where(models.Task.result == result) if is_defined(error): query = query.where(models.Task.error == error) if is_defined(release_name): query = query.where(models.Task.release_name == release_name) if _release: query = query.options(select_in_load(models.Task.release)) return Query(self, query) def text_value( self, ns: Opt[str] = NOT_SET, key: Opt[str] = NOT_SET, value: Opt[str] = NOT_SET, ) -> Query[models.TextValue]: query = sqlmodel.select(models.TextValue) if is_defined(ns): query = query.where(models.TextValue.ns == ns) if is_defined(key): query = query.where(models.TextValue.key == key) if is_defined(value): query = query.where(models.TextValue.value == value) return Query(self, query) async def create_async_engine(app_config: type[config.AppConfig]) -> sqlalchemy.ext.asyncio.AsyncEngine: absolute_db_path = os.path.join(app_config.STATE_DIR, app_config.SQLITE_DB_PATH) # Three slashes are required before either a relative or absolute path sqlite_url = f"sqlite+aiosqlite:///{absolute_db_path}" # Use aiosqlite for async SQLite access engine = sqlalchemy.ext.asyncio.create_async_engine( sqlite_url, connect_args={ "check_same_thread": False, "timeout": 30, }, ) # Set SQLite pragmas for better performance # Use 64 MB for the cache_size, and 5000ms for busy_timeout async with engine.begin() as conn: await conn.execute(sql.text("PRAGMA journal_mode=WAL")) await conn.execute(sql.text("PRAGMA synchronous=NORMAL")) await conn.execute(sql.text("PRAGMA cache_size=-64000")) await conn.execute(sql.text("PRAGMA foreign_keys=ON")) await conn.execute(sql.text("PRAGMA busy_timeout=5000")) return engine async def get_project_release_policy(data: Session, project_name: str) -> models.ReleasePolicy | None: """Fetch the ReleasePolicy for a project.""" project = await data.project(name=project_name, _release_policy=True).demand( RuntimeError(f"Project {project_name} not found") ) return project.release_policy def init_database(app: base.QuartApp) -> None: """ Creates and initializes the database for a QuartApp. The database is created and an AsyncSession is registered as extension for the app. Any pending migrations are executed. """ @app.before_serving async def create() -> None: global _global_atr_engine, _global_atr_sessionmaker app_config = config.get() engine = await create_async_engine(app_config) _global_atr_engine = engine _global_atr_sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker( bind=engine, class_=Session, expire_on_commit=False ) # Run any pending migrations on startup _LOGGER.info("Applying database migrations via init_database...") alembic_ini_path = os.path.join(app_config.PROJECT_ROOT, "alembic.ini") alembic_cfg = Config(alembic_ini_path) # Construct synchronous URLs absolute_db_path = os.path.join(app_config.STATE_DIR, app_config.SQLITE_DB_PATH) sync_sqlalchemy_url = f"sqlite:///{absolute_db_path}" _LOGGER.info(f"Setting Alembic URL for command: {sync_sqlalchemy_url}") alembic_cfg.set_main_option("sqlalchemy.url", sync_sqlalchemy_url) # Ensure that Alembic finds the migrations directory relative to project root migrations_dir_path = os.path.join(app_config.PROJECT_ROOT, "migrations") _LOGGER.info(f"Setting Alembic script_location for command: {migrations_dir_path}") alembic_cfg.set_main_option("script_location", migrations_dir_path) try: _LOGGER.info("Running alembic upgrade head...") command.upgrade(alembic_cfg, "head") _LOGGER.info("Database migrations applied successfully") except Exception: _LOGGER.exception("Failed to apply database migrations during startup") raise try: _LOGGER.info("Running alembic check...") command.check(alembic_cfg) _LOGGER.info("Alembic check passed: DB schema matches models") except Exception: _LOGGER.exception("Failed to check database migrations during startup") raise async def init_database_for_worker() -> None: global _global_atr_engine, _global_atr_sessionmaker _LOGGER.info(f"Creating database for worker {os.getpid()}") engine = await create_async_engine(config.get()) _global_atr_engine = engine _global_atr_sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker( bind=engine, class_=Session, expire_on_commit=False ) def is_defined(v: T | NotSet) -> TypeGuard[T]: return not isinstance(v, NotSet) def is_undefined(v: object | NotSet) -> TypeGuard[NotSet]: return isinstance(v, NotSet) # async def recent_tasks(data: Session, release_name: str, file_path: str, modified: int) -> dict[str, models.Task]: # """Get the most recent task for each task type for a specific file.""" # tasks = await data.task( # release_name=release_name, # path=str(file_path), # modified=modified, # ).all() # # # Group by task_type and keep the most recent one # # We use the highest id to determine the most recent task # recent_tasks: dict[str, models.Task] = {} # for task in tasks: # # If we haven't seen this task type before or if this task is newer # if (task.task_type.value not in recent_tasks) or (task.id > recent_tasks[task.task_type.value].id): # recent_tasks[task.task_type.value] = task # # return recent_tasks async def path_info(release: models.Release, paths: list[pathlib.Path]) -> PathInfo: info = PathInfo() for path in paths: # Get template and substitutions # elements = { # "core": release.project.name, # "version": release.version, # "sub": None, # "template": None, # "substitutions": None, # } # template, substitutions = analysis.filename_parse(str(path), elements) # info.templates[path] = template # info.substitutions[path] = analysis.substitutions_format(substitutions) or "none" # Get artifacts and metadata search = re.search(analysis.extension_pattern(), str(path)) if search: if search.group("artifact"): info.artifacts.add(path) elif search.group("metadata"): info.metadata.add(path) # Get successes, warnings, and errors async with session() as data: info.successes[path] = list( await data.check_result( release_name=release.name, primary_rel_path=str(path), status=models.CheckResultStatus.SUCCESS ).all() ) info.warnings[path] = list( await data.check_result( release_name=release.name, primary_rel_path=str(path), status=models.CheckResultStatus.WARNING ).all() ) info.errors[path] = list( await data.check_result( release_name=release.name, primary_rel_path=str(path), status=models.CheckResultStatus.FAILURE ).all() ) return info def select_in_load(*entities: Any) -> orm.strategy_options._AbstractLoad: """Eagerly load the given entities from the query.""" validated_entities = [] for entity in entities: if not isinstance(entity, orm.InstrumentedAttribute): raise ValueError(f"Object must be an orm.InstrumentedAttribute, got: {type(entity)}") validated_entities.append(entity) return orm.selectinload(*validated_entities) def select_in_load_nested(parent: Any, *descendants: Any) -> orm.strategy_options._AbstractLoad: """Eagerly load the given nested entities from the query.""" if not isinstance(parent, orm.InstrumentedAttribute): raise ValueError(f"Parent must be an orm.InstrumentedAttribute, got: {type(parent)}") for descendant in descendants: if not isinstance(descendant, orm.InstrumentedAttribute): raise ValueError(f"Descendant must be an orm.InstrumentedAttribute, got: {type(descendant)}") result = orm.selectinload(parent) for descendant in descendants: result = result.selectinload(descendant) return result def session() -> Session: """Create a new asynchronous database session.""" # FIXME: occasionally you see this in the console output # <sys>:0: SAWarning: The garbage collector is trying to clean up non-checked-in connection <AdaptedConnection # <Connection(Thread-291, started daemon 138838634661440)>>, which will be dropped, as it cannot be safely # terminated. Please ensure that SQLAlchemy pooled connections are returned to the pool explicitly, either by # calling ``close()`` or by using appropriate context managers to manage their lifecycle. # Not fully clear where this is coming from, but we could experiment by returning a session like that: # async def session() -> AsyncGenerator[Session, None]: # async with _global_atr_sessionmaker() as session: # yield session # from FastAPI documentation: https://fastapi-users.github.io/fastapi-users/latest/configuration/databases/sqlalchemy/ if _global_atr_sessionmaker is None: raise RuntimeError("database not initialized") else: return util.validate_as_type(_global_atr_sessionmaker(), Session) async def shutdown_database() -> None: if _global_atr_engine: _LOGGER.info("Closing database") await _global_atr_engine.dispose() else: _LOGGER.info("No database to close") async def tasks_ongoing(project_name: str, version_name: str, draft_revision: str) -> int: release_name = models.release_name(project_name, version_name) async with session() as data: query = ( sqlmodel.select(sqlalchemy.func.count()) .select_from(models.Task) .where( models.Task.release_name == release_name, models.Task.draft_revision == draft_revision, validate_instrumented_attribute(models.Task.status).in_( [models.TaskStatus.QUEUED, models.TaskStatus.ACTIVE] ), ) ) result = await data.execute(query) return result.scalar_one() async def unfinished_releases(asfuid: str) -> dict[str, list[models.Release]]: releases: dict[str, list[models.Release]] = {} async with session() as data: user_projects = await user.projects(asfuid) user_projects.sort(key=lambda p: p.display_name) active_phases = [ models.ReleasePhase.RELEASE_CANDIDATE_DRAFT, models.ReleasePhase.RELEASE_CANDIDATE, models.ReleasePhase.RELEASE_PREVIEW, ] for project in user_projects: stmt = ( sqlmodel.select(models.Release) .where( models.Release.project_name == project.name, validate_instrumented_attribute(models.Release.phase).in_(active_phases), ) .options(select_in_load(models.Release.project)) .order_by(validate_instrumented_attribute(models.Release.created).desc()) ) result = await data.execute(stmt) active_releases = list(result.scalars().all()) if active_releases: active_releases.sort(key=lambda r: r.created, reverse=True) releases[project.short_display_name] = active_releases return releases def validate_instrumented_attribute(obj: Any) -> orm.InstrumentedAttribute: """Check if the given object is an InstrumentedAttribute.""" if not isinstance(obj, orm.InstrumentedAttribute): raise ValueError(f"Object must be an orm.InstrumentedAttribute, got: {type(obj)}") return obj