atr/routes/__init__.py (286 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import asyncio
import functools
import logging
import time
from typing import TYPE_CHECKING, Any, Final, NoReturn, ParamSpec, Protocol, TypeVar
import aiofiles
import aiofiles.os
import asfquart
import asfquart.auth as auth
import asfquart.base as base
import asfquart.session as session
import quart
import atr.db as db
import atr.db.models as models
import atr.user as user
import atr.util as util
from atr import config
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Coroutine, Sequence
import werkzeug.datastructures as datastructures
import werkzeug.wrappers.response as response
# app_route below registers handlers through asfquart.APP.route
# Ellipsis is the value checked for here as "not set", so fail at import time if the app is missing
if asfquart.APP is ...:
    raise RuntimeError("APP is not set")

# Type variables used by the decorator and protocol signatures in this module
# P: parameters of a wrapped route coroutine
# R: covariant result type of the handler protocols
# T: result type of a wrapped route coroutine
P = ParamSpec("P")
R = TypeVar("R", covariant=True)
T = TypeVar("T")

# Global switch for performance logging, combined with the per-route flag in app_route
# TODO: Should get this from config, checking debug there
_MEASURE_PERFORMANCE: Final[bool] = True

# Module logger
_LOGGER: Final = logging.getLogger(__name__)
# Public key algorithm identifiers, as listed in the post linked below
#
# | ID | Algorithm                                          |
# | 1  | RSA (Encrypt or Sign) [HAC]                        |
# | 2  | RSA Encrypt-Only [HAC]                             |
# | 3  | RSA Sign-Only [HAC]                                |
# | 16 | Elgamal (Encrypt-Only) [ELGAMAL] [HAC]             |
# | 17 | DSA (Digital Signature Algorithm) [FIPS186] [HAC]  |
# | 18 | ECDH public key algorithm                          |
# | 19 | ECDSA public key algorithm [FIPS186]               |
# | 20 | Reserved (formerly Elgamal Encrypt or Sign)        |
# | 21 | Reserved for Diffie-Hellman                        |
# |    | (X9.42, as defined for IETF-S/MIME)                |
# | 22 | EdDSA [I-D.irtf-cfrg-eddsa]                        |
#
# - https://lists.gnupg.org/pipermail/gnupg-devel/2017-April/032762.html
#
# The mapping below gives a short display name per identifier
# Identifier 20 (reserved) has no entry
# TODO: (Obviously we should move this, but where to?)
algorithms: Final[dict[int, str]] = {
    # The three RSA variants share one display name
    **dict.fromkeys((1, 2, 3), "RSA"),
    16: "Elgamal",
    17: "DSA",
    18: "ECDH",
    19: "ECDSA",
    21: "Diffie-Hellman",
    22: "EdDSA",
}
class AsyncFileHandler(logging.Handler):
    """Logging handler that queues records and writes them from an asyncio task using aiofiles.

    Records are put on an asyncio.Queue by emit() and written by a single
    background task, which is created on the first emit() that happens while
    an event loop is running.
    """

    def __init__(self, filename, mode="w", encoding=None):
        """Store the target filename and encoding; only mode "w" is accepted."""
        super().__init__()
        self.filename = filename
        if mode != "w":
            raise RuntimeError("Only write mode is supported")
        self.encoding = encoding
        self.queue = asyncio.Queue()
        # Created lazily by our_worker_task_ensure
        self.our_worker_task = None

    def our_worker_task_ensure(self):
        """Start the writer task if it has not been created and a loop is running."""
        if self.our_worker_task is not None:
            return
        try:
            self.our_worker_task = asyncio.get_running_loop().create_task(self.our_worker())
        except RuntimeError:
            # No running event loop yet, so the next emit will try again
            pass

    async def our_worker(self):
        """Drain the queue into the log file until a None sentinel arrives."""
        # aiofiles.open should accept any mode, but pyright requires a binary mode literal
        # https://github.com/Tinche/aiofiles/blob/main/src/aiofiles/threadpool/__init__.py
        async with aiofiles.open(self.filename, "wb+") as log_file:
            while (record := await self.queue.get()) is not None:
                try:
                    # Format first, then encode with the configured encoding or UTF-8
                    line = self.format(record) + "\n"
                    await log_file.write(line.encode(self.encoding or "utf-8"))
                    await log_file.flush()
                except Exception:
                    self.handleError(record)
                finally:
                    self.queue.task_done()

    def emit(self, record):
        """Hand the record to the writer task through the queue."""
        try:
            # Make sure there is a writer task if that is currently possible
            self.our_worker_task_ensure()
            try:
                self.queue.put_nowait(record)
            except RuntimeError:
                # Tolerate the queue being unusable without a running event loop
                pass
        except Exception:
            self.handleError(record)

    def close(self):
        """Ask a still-running writer task to stop, then close the handler."""
        worker = self.our_worker_task
        worker_is_running = (worker is not None) and (not worker.done())
        if worker_is_running:
            try:
                # None is the sentinel that makes our_worker leave its loop
                self.queue.put_nowait(None)
            except RuntimeError:
                # No running event loop, no need to clean up
                pass
        super().close()
# Structural type of the functions to which we apply @committer
# That is, functions which take a CommitterSession as their first argument
class CommitterRouteHandler(Protocol[R]):
    """Callable protocol for handlers that receive a CommitterSession first and return an awaitable."""

    __doc__: str | None
    __name__: str

    def __call__(
        self,
        session: CommitterSession,
        *args: Any,
        **kwargs: Any,
    ) -> Awaitable[R]: ...
class CommitterSession:
    """Session with extra information about committers.

    Wraps a web session object; any attribute not defined on this class is
    looked up on the wrapped session (self.uid is read that way below).
    """

    def __init__(self, web_session: session.ClientSession) -> None:
        # Cache for user_projects, None until it is first awaited
        self._projects: list[models.Project] | None = None
        self._session = web_session

    def __getattr__(self, name: str) -> Any:
        # TODO: Not type safe, should subclass properly if possible
        # For example, we can access session.no_such_attr and the type checkers won't notice
        # __getattr__ only runs after normal lookup fails
        # So being asked for _session means __init__ has not run, e.g. during copying or unpickling
        # Without this guard the getattr below would recurse until RecursionError
        if name == "_session":
            raise AttributeError(name)
        return getattr(self._session, name)

    async def check_access(self, project_name: str) -> None:
        """Raise a 403 ASFQuartException unless project_name is among the user's projects."""
        projects = await self.user_projects
        if not any((p.name == project_name) for p in projects):
            raise base.ASFQuartException("You do not have access to this project", errorcode=403)

    @property
    def app_host(self) -> str:
        """The APP_HOST value from the application configuration."""
        return config.get().APP_HOST

    @property
    def host(self) -> str:
        """The host of the current request with a trailing numeric port removed.

        A value without a numeric port is returned unchanged. Bracketed IPv6
        literals such as "[::1]:8080" are supported and keep their brackets.
        """
        request_host = quart.request.host
        # Split on the last colon only, because an IPv6 literal contains several colons
        # Unpacking request_host.split(":") into two names raised ValueError for those
        domain, separator, port = request_host.rpartition(":")
        if not (separator and port.isdigit()):
            return request_host
        # A colon remaining in the domain is only valid inside a bracketed IPv6 literal
        # Otherwise the trailing digits are part of an address, not a port
        if (":" in domain) and not (domain.startswith("[") and domain.endswith("]")):
            return request_host
        return domain

    def only_user_releases(self, releases: Sequence[models.Release]) -> list[models.Release]:
        """Filter releases through util.user_releases for this session's uid."""
        return util.user_releases(self.uid, releases)

    async def redirect(
        self, route: CommitterRouteHandler[R], success: str | None = None, error: str | None = None, **kwargs: Any
    ) -> response.Response:
        """Redirect to a route with a success or error message.

        At most one message is flashed; success takes precedence over error.
        Extra keyword arguments are passed to util.as_url to build the URL.
        """
        if success is not None:
            await quart.flash(success, "success")
        elif error is not None:
            await quart.flash(error, "error")
        return quart.redirect(util.as_url(route, **kwargs))

    async def release(
        self,
        project_name: str,
        version_name: str,
        phase: models.ReleasePhase | db.NotSet | None = db.NOT_SET,
        data: db.Session | None = None,
        with_committee: bool = False,
        with_packages: bool = False,
        with_project: bool = True,
        with_tasks: bool = False,
    ) -> models.Release:
        """Fetch a release by project and version, raising a 404 ASFQuartException if absent.

        phase selects the release phase to match:
        - omitted (db.NOT_SET) means RELEASE_CANDIDATE_DRAFT
        - None means no phase is specified in the query (db.NOT_SET is passed on)
        - any other value is passed on unchanged

        data is an existing database session to use; when None a new session
        is opened for the duration of the query. The with_committee,
        with_project and with_tasks flags are forwarded to the query as
        _committee, _project and _tasks. with_packages is accepted but is not
        used by this method.
        """
        # We reuse db.NOT_SET as an entirely different sentinel
        # TODO: We probably shouldn't do that, or should make it clearer
        if phase is None:
            phase_value = db.NOT_SET
        elif phase is db.NOT_SET:
            phase_value = models.ReleasePhase.RELEASE_CANDIDATE_DRAFT
        else:
            phase_value = phase
        release_name = models.release_name(project_name, version_name)
        if data is None:
            async with db.session() as data:
                release = await self._demand_release(
                    data,
                    release_name,
                    phase_value,
                    with_committee=with_committee,
                    with_project=with_project,
                    with_tasks=with_tasks,
                )
        else:
            release = await self._demand_release(
                data,
                release_name,
                phase_value,
                with_committee=with_committee,
                with_project=with_project,
                with_tasks=with_tasks,
            )
        return release

    @staticmethod
    async def _demand_release(
        data: db.Session,
        release_name: str,
        phase_value: models.ReleasePhase | db.NotSet,
        *,
        with_committee: bool,
        with_project: bool,
        with_tasks: bool,
    ) -> models.Release:
        """Run the release query on data, raising a 404 ASFQuartException when nothing matches."""
        return await data.release(
            name=release_name,
            phase=phase_value,
            _committee=with_committee,
            _project=with_project,
            _tasks=with_tasks,
        ).demand(base.ASFQuartException("Release does not exist", errorcode=404))

    @property
    async def user_candidate_drafts(self) -> list[models.Release]:
        """Awaitable property: the result of user.candidate_drafts for this uid."""
        # Passes the cached project list, which is None if user_projects has not been awaited yet
        # NOTE(review): presumably user.candidate_drafts handles None itself, confirm
        return await user.candidate_drafts(self.uid, user_projects=self._projects)

    # @property
    # async def user_committees(self) -> list[models.Committee]:
    #     return ...

    @property
    async def user_projects(self) -> list[models.Project]:
        """Awaitable property: the user's projects, fetched once and then cached."""
        if self._projects is None:
            self._projects = await user.projects(self.uid)
        return self._projects

    @property
    async def user_releases(self) -> list[models.Release]:
        """Awaitable property: the result of user.releases for this uid."""
        return await user.releases(self.uid)
class FlashError(RuntimeError):
    """RuntimeError subclass with no added behaviour.

    NOTE(review): presumably raised for errors that should be flashed to the user, confirm against callers
    """
class MicrosecondsFormatter(logging.Formatter):
    """Formatter whose asctime separates the fractional part with a period.

    Despite the class name, the fraction is the three-digit millisecond value.
    """

    # The logging.Formatter default puts a comma before the milliseconds
    # Answers on a postcard if you know why Python chose that
    default_msec_format = "%s.%03d"
# Setup a dedicated logger for route performance metrics
# NOTE: This code block must come after AsyncFileHandler and MicrosecondsFormatter
route_logger: Final = logging.getLogger("route.performance")

# The handler writes to a relative path, and AsyncFileHandler opens it with "wb+"
# So an existing file is truncated when the worker task starts
# The formatter prints the timestamp with a period and three millisecond digits
# TODO: Is this actually UTC?
# NOTE(review): no converter is set here, and logging.Formatter defaults to time.localtime
# So this is local time unless the process timezone is UTC
route_logger_handler: Final[AsyncFileHandler] = AsyncFileHandler("route-performance.log")
route_logger_handler.setFormatter(MicrosecondsFormatter("%(asctime)s - %(message)s"))
route_logger.addHandler(route_logger_handler)
route_logger.setLevel(logging.INFO)

# If we don't set propagate to False then it logs to the term as well
route_logger.propagate = False
# Structural type of the functions to which we apply @app_route
# That is, functions which do not take a session argument
class RouteHandler(Protocol[R]):
    """Callable protocol for handlers that take only route arguments and return an awaitable."""

    __doc__: str | None
    __name__: str

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Awaitable[R]: ...
def app_route(
    path: str, methods: list[str] | None = None, endpoint: str | None = None, measure_performance: bool = True
) -> Callable:
    """Return a decorator that registers a coroutine on asfquart.APP, with optional performance logging.

    The handler is wrapped by app_route_performance_measure only when both the
    module-level _MEASURE_PERFORMANCE switch and measure_performance are true.
    path, methods and endpoint are passed on to asfquart.APP.route.
    """

    def decorator(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Awaitable[T]]:
        # The measuring wrapper goes on the inside, so that the route decorator sees the final callable
        wants_measurement = _MEASURE_PERFORMANCE and measure_performance
        handler = app_route_performance_measure(path, methods)(f) if wants_measurement else f
        # Registration goes through the route decorator of the application
        return asfquart.APP.route(path, methods=methods, endpoint=endpoint)(handler)

    return decorator
def app_route_performance_measure(route_path: str, http_methods: list[str] | None = None) -> Callable:
    """Decorator that measures and logs route performance with path and method information.

    The decorated coroutine function is driven step by step by the wrapper.
    Time spent inside each coro.send(None) call is counted as blocking time,
    and time spent waiting for the yielded future is counted as nonblocking.

    When the coroutine finishes, one line is logged to route_logger with these
    space-separated fields: HTTP methods (comma-joined, "GET" if none were
    given), route path, function name, a delta symbol, then blocking,
    nonblocking and total time as integer milliseconds. The delta symbol is
    "!" when the two nonblocking measurements differ by more than 10% of the
    larger one, and "=" otherwise.

    The wrapper returns the value of the wrapped coroutine. Nothing is logged
    if the coroutine raises.
    """

    # Unused formatting helper, kept for reference
    # def format_time(seconds: float) -> str:
    #     """Format time in appropriate units (µs or ms)."""
    #     microseconds = seconds * 1_000_000
    #     if microseconds < 1000:
    #         return f"{microseconds:.2f} µs"
    #     else:
    #         milliseconds = microseconds / 1000
    #         return f"{milliseconds:.2f} ms"

    def decorator(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Awaitable[T]]:
        @functools.wraps(f)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            # This wrapper is based on an outstanding idea by Mostafa Farzán
            # Farzán realised that we can step the event loop manually
            # That way, we can also divide it into synchronous and asynchronous parts
            # The synchronous part is done using coro.send(None)
            # The asynchronous part is done using asyncio.sleep(0)
            # We use two methods for measuring the async part, and take the largest
            # This performance measurement adds a bit of overhead, about 10-20ms
            # Therefore it should be avoided in production, or made more efficient
            # We could perhaps use it for a small portion of requests
            blocking_time = 0.0
            async_time = 0.0
            loop_time = 0.0
            total_start = time.perf_counter()
            coro = f(*args, **kwargs)
            try:
                while True:
                    # Measure the synchronous part
                    # Runs the coroutine until its next suspension point
                    # StopIteration from send means the coroutine returned
                    # Any other exception raised by the coroutine propagates out of the wrapper here
                    sync_start = time.perf_counter()
                    future = coro.send(None)
                    sync_end = time.perf_counter()
                    blocking_time += sync_end - sync_start
                    # Measure the asynchronous part in two different ways
                    # Namely wall clock time from perf_counter, and the event loop's own clock
                    loop = asyncio.get_running_loop()
                    wait_start = time.perf_counter()
                    loop_start = loop.time()
                    # NOTE(review): when the coroutine yields None, nothing is awaited here
                    # So control is not handed to the event loop before the next send
                    if future is not None:
                        # Poll until the yielded future completes, yielding to the loop each time
                        # NOTE(review): this is a busy wait while the future is pending
                        # NOTE(review): if this task is cancelled while polling, the exception is raised here
                        # It is not thrown into coro, so confirm that cleanup inside coro is not relied upon
                        while not future.done():
                            await asyncio.sleep(0)
                    wait_end = time.perf_counter()
                    loop_end = loop.time()
                    async_time += wait_end - wait_start
                    loop_time += loop_end - loop_start
                    # The result of the future is not inspected here
                    # The next send resumes the coroutine, which deals with the completed future itself
                    # Raise exception if any
                    # future.result()
            except StopIteration as e:
                # The coroutine has returned, and e.value holds its return value
                total_end = time.perf_counter()
                total_time = total_end - total_start
                methods_str = ",".join(http_methods) if http_methods else "GET"
                # Report the larger of the two nonblocking measurements
                nonblocking_time = max(async_time, loop_time)
                # If async time is more than 10% different from loop time, log it
                delta_symbol = "="
                nonblocking_delta = abs(async_time - loop_time)
                # Must check that nonblocking_time is not 0 to avoid division by zero
                if nonblocking_time and ((nonblocking_delta / nonblocking_time) > 0.1):
                    delta_symbol = "!"
                # Times are logged as whole milliseconds
                route_logger.info(
                    "%s %s %s %s %s %s %s",
                    methods_str,
                    route_path,
                    f.__name__,
                    delta_symbol,
                    int(blocking_time * 1000),
                    int(nonblocking_time * 1000),
                    int(total_time * 1000),
                )
                return e.value

        return wrapper

    return decorator
# Adaptor from handlers that take a CommitterSession to plain @app_route handlers
def committer(
    path: str, methods: list[str] | None = None, measure_performance: bool = True
) -> Callable[[CommitterRouteHandler[R]], RouteHandler[R]]:
    """Decorator for committer-only routes whose handler receives a CommitterSession first.

    The handler is registered at path through app_route, for the given methods
    or ["GET"] when none are given, behind the committer auth requirement.
    The endpoint name is the handler's module and function name joined by "_".
    """

    def decorator(func: CommitterRouteHandler[R]) -> RouteHandler[R]:
        async def wrapper(*args: Any, **kwargs: Any) -> R:
            # Read the web session, and fail with a 401 if there is none
            web_session = await session.read()
            if web_session is None:
                _authentication_failed()
            return await func(CommitterSession(web_session), *args, **kwargs)

        # Generate a unique endpoint name
        endpoint = "_".join((func.__module__, func.__name__))

        # Give the wrapper the identity of the handler before any decorator sees it
        for attribute in ("__name__", "__doc__"):
            setattr(wrapper, attribute, getattr(func, attribute))
        wrapper.__annotations__["endpoint"] = endpoint

        # Innermost first, so the auth requirement is applied before route registration
        guarded = auth.require(auth.Requirements.committer)(wrapper)
        register = app_route(
            path, methods=methods or ["GET"], endpoint=endpoint, measure_performance=measure_performance
        )
        return register(guarded)

    return decorator
async def get_form(request: quart.Request) -> datastructures.MultiDict:
    """Return the parsed form data of request, with blockbuster switched off while parsing.

    Raises RuntimeError if asfquart.APP is not set. If the app has a
    "blockbuster" extension it is deactivated while the form is awaited, and
    reactivated afterwards even when parsing raises.
    """
    # The request.form() method in Quart calls a synchronous tempfile method
    # It calls quart.wrappers.request.form _load_form_data
    # Which calls quart.formparser parse and parse_func and parser.parse
    # Which calls _write which calls tempfile, which is synchronous
    # It's getting a tempfile back from some prior call
    # We can't just make blockbuster ignore the call because then it ignores it everywhere
    app = asfquart.APP
    if app is ...:
        raise RuntimeError("APP is not set")

    # Or quart.current_app?
    blockbuster = app.extensions.get("blockbuster")
    if blockbuster is None:
        return await request.form

    # Turn blockbuster off
    blockbuster.deactivate()
    try:
        return await request.form
    finally:
        # Turn blockbuster on, in a finally block
        # Otherwise a failed parse would leave it off for the rest of the process
        blockbuster.activate()
def public(
    path: str, methods: list[str] | None = None, measure_performance: bool = True
) -> Callable[[RouteHandler[R]], RouteHandler[R]]:
    """Decorator for public routes, where no session is read or passed to the handler.

    The handler is registered at path through app_route, for the given methods
    or ["GET"] when none are given. The endpoint name is the handler's module
    and function name joined by "_".
    """

    def decorator(func: RouteHandler[R]) -> RouteHandler[R]:
        async def wrapper(*args: Any, **kwargs: Any) -> R:
            # Plain pass-through to the handler
            return await func(*args, **kwargs)

        # Generate a unique endpoint name
        endpoint = "_".join((func.__module__, func.__name__))

        # Give the wrapper the identity of the handler before registration
        for attribute in ("__name__", "__doc__"):
            setattr(wrapper, attribute, getattr(func, attribute))
        wrapper.__annotations__["endpoint"] = endpoint

        register = app_route(
            path, methods=methods or ["GET"], endpoint=endpoint, measure_performance=measure_performance
        )
        return register(wrapper)

    return decorator
def _authentication_failed() -> NoReturn:
    """Raise the 401 ASFQuartException used when no web session is available."""
    # NOTE: Kept as its own NoReturn function, which fixes a problem with analysis flow in mypy
    not_authenticated = base.ASFQuartException("Not authenticated", errorcode=401)
    raise not_authenticated