atr/tasks/checks/__init__.py (123 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import dataclasses
import datetime
import functools
import pathlib
from typing import TYPE_CHECKING, Any
import sqlmodel
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
import pydantic
import atr.db as db
import atr.db.models as models
import atr.util as util
# Pydantic does not like Callable types, so we use a dataclass instead
# It says: "you should define `Callable`, then call `FunctionArguments.model_rebuild()`"
@dataclasses.dataclass
class FunctionArguments:
    """Bundle of arguments handed to a check function.

    A plain dataclass is used here rather than a Pydantic model because one
    of the fields is a ``Callable``.
    """

    # Zero-argument callable returning an awaitable that yields a Recorder
    recorder: Callable[[], Awaitable[Recorder]]
    # Name of the release the check applies to
    release_name: str
    # Draft revision identifier
    draft_revision: str
    # Relative path of the primary file, or None when no path is given
    primary_rel_path: str | None
    # Additional arguments, keyed by name
    extra_args: dict[str, Any]
class Recorder:
checker: str
release_name: str
project_name: str
version_name: str
primary_rel_path: str | None
draft_revision: str
afresh: bool
def __init__(
self,
checker: str | Callable[..., Any],
release_name: str,
draft_revision: str,
primary_rel_path: str | None = None,
afresh: bool = True,
) -> None:
self.checker = function_key(checker) if callable(checker) else checker
self.release_name = release_name
self.draft_revision = draft_revision
self.primary_rel_path = primary_rel_path
self.afresh = afresh
self.constructed = False
project_name, version_name = models.project_version(release_name)
self.project_name = project_name
self.version_name = version_name
@classmethod
async def create(
cls,
checker: str | Callable[..., Any],
release_name: str,
draft_revision: str,
primary_rel_path: str | None = None,
afresh: bool = True,
) -> Recorder:
recorder = cls(checker, release_name, draft_revision, primary_rel_path, afresh)
if afresh is True:
# Clear outer path whether it's specified or not
await recorder.clear(primary_rel_path)
recorder.constructed = True
return recorder
async def _add(
self, status: models.CheckResultStatus, message: str, data: Any, primary_rel_path: str | None = None
) -> models.CheckResult:
if self.constructed is False:
raise RuntimeError("Cannot add check result to a recorder that has not been constructed")
if primary_rel_path is not None:
if self.primary_rel_path is not None:
raise ValueError("Cannot specify path twice")
if self.afresh is True:
# Clear inner path only if it's specified
await self.clear(primary_rel_path)
result = models.CheckResult(
release_name=self.release_name,
checker=self.checker,
primary_rel_path=primary_rel_path or self.primary_rel_path,
created=datetime.datetime.now(datetime.UTC),
status=status,
message=message,
data=data,
)
# It would be more efficient to keep a session open
# But, we prefer in this case to maintain a simpler interface
# If performance is unacceptable, we can revisit this design
async with db.session() as session:
session.add(result)
await session.commit()
return result
async def abs_path(self, rel_path: str | None = None) -> pathlib.Path | None:
"""Construct the absolute path using the required draft_revision."""
base_dir = util.get_unfinished_dir()
project_part = self.project_name
version_part = self.version_name
revision_part = self.draft_revision
# Determine the relative path part
rel_path_part: str | None = None
if rel_path is not None:
rel_path_part = rel_path
elif self.primary_rel_path is not None:
rel_path_part = self.primary_rel_path
# Construct the absolute path
abs_path_parts: list[str | pathlib.Path] = [base_dir, project_part, version_part, revision_part]
if isinstance(rel_path_part, str):
abs_path_parts.append(rel_path_part)
return pathlib.Path(*abs_path_parts)
async def clear(self, primary_rel_path: str | None = None) -> None:
async with db.session() as data:
stmt = sqlmodel.delete(models.CheckResult).where(
db.validate_instrumented_attribute(models.CheckResult.release_name) == self.release_name,
db.validate_instrumented_attribute(models.CheckResult.checker) == self.checker,
db.validate_instrumented_attribute(models.CheckResult.primary_rel_path) == primary_rel_path,
)
await data.execute(stmt)
await data.commit()
async def exception(self, message: str, data: Any, primary_rel_path: str | None = None) -> models.CheckResult:
return await self._add(models.CheckResultStatus.EXCEPTION, message, data, primary_rel_path=primary_rel_path)
async def failure(self, message: str, data: Any, primary_rel_path: str | None = None) -> models.CheckResult:
return await self._add(models.CheckResultStatus.FAILURE, message, data, primary_rel_path=primary_rel_path)
async def success(self, message: str, data: Any, primary_rel_path: str | None = None) -> models.CheckResult:
return await self._add(models.CheckResultStatus.SUCCESS, message, data, primary_rel_path=primary_rel_path)
async def warning(self, message: str, data: Any, primary_rel_path: str | None = None) -> models.CheckResult:
return await self._add(models.CheckResultStatus.WARNING, message, data, primary_rel_path=primary_rel_path)
def function_key(func: Callable[..., Any]) -> str:
    """Return the dotted key for *func*: its module name, a dot, then its name."""
    key_parts = (func.__module__, func.__name__)
    return ".".join(key_parts)
def with_model(cls: type[pydantic.BaseModel]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that feeds a check function a ``cls`` instance.

    The decorated function becomes an async callable whose first argument is
    a dict. That dict is expanded as keyword arguments to ``cls``, and the
    resulting instance is passed to the wrapped function in its place. Any
    remaining positional and keyword arguments are forwarded unchanged.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapper(data_dict: dict[str, Any], *args: Any, **kwargs: Any) -> Any:
            # Build the model first so construction errors surface before func runs
            parsed_args = cls(**data_dict)
            outcome = await func(parsed_args, *args, **kwargs)
            return outcome

        return wrapper

    return decorator