atr/util.py:

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio
import binascii
import contextlib
import dataclasses
import datetime
import hashlib
import logging
import pathlib
import re
import shutil
import tarfile
import tempfile
import uuid
import zipfile
from collections.abc import AsyncGenerator, Callable, Generator, ItemsView, Mapping, Sequence
from typing import Annotated, Any, Final, TypeVar

import aiofiles.os
import asfquart
import asfquart.base as base
import asfquart.session as session
import jinja2
import pydantic
import pydantic_core
import quart
import quart_wtf
import quart_wtf.typing
import wtforms

# NOTE: The atr.db module imports this module
# Therefore, this module must not import atr.db
import atr.config as config
import atr.db.models as models
import atr.user as user

F = TypeVar("F", bound="QuartFormTyped")
T = TypeVar("T")
VT = TypeVar("VT")

_LOGGER: Final = logging.getLogger(__name__)


class DictRootModel(pydantic.RootModel[dict[str, VT]]):
    def __iter__(self) -> Generator[tuple[str, VT]]:
        yield from self.root.items()

    def items(self) -> ItemsView[str, VT]:
        return self.root.items()

    def get(self, key: str) -> VT | None:
        return self.root.get(key)

    def __len__(self) -> int:
        return len(self.root)


# from https://github.com/pydantic/pydantic/discussions/8755#discussioncomment-8417979
@dataclasses.dataclass
class DictToList:
    key: str

    def __get_pydantic_core_schema__(
        self,
        source_type: Any,
        handler: pydantic.GetCoreSchemaHandler,
    ) -> pydantic_core.CoreSchema:
        adapter = _get_dict_to_list_inner_type_adapter(source_type, self.key)

        return pydantic_core.core_schema.no_info_before_validator_function(
            _get_dict_to_list_validator(adapter, self.key),
            handler(source_type),
        )


@dataclasses.dataclass
class FileStat:
    path: str
    modified: int
    size: int
    permissions: int
    is_file: bool
    is_dir: bool


class QuartFormTyped(quart_wtf.QuartForm):
    """Quart form with type annotations."""

    @classmethod
    async def create_form(
        cls: type[F],
        formdata: object | quart_wtf.typing.FormData = quart_wtf.form._Auto,
        obj: Any | None = None,
        prefix: str = "",
        data: dict | None = None,
        meta: dict | None = None,
        **kwargs: dict[str, Any],
    ) -> F:
        """Create a form instance with typing."""
        form = await super().create_form(formdata, obj, prefix, data, meta, **kwargs)
        if not isinstance(form, cls):
            raise TypeError(f"Form is not of type {cls.__name__}")
        return form
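
# Illustrative usage of QuartFormTyped (the form and handler are hypothetical,
# not part of this module): create_form returns an instance typed as the
# subclass, so callers get full type checking inside a request handler:
#
#     class UploadForm(QuartFormTyped):
#         name = wtforms.StringField("Name")
#
#     async def handler() -> str:
#         form = await UploadForm.create_form()
#         ...
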
async def archive_listing(file_path: pathlib.Path) -> list[str] | None:
    """Attempt to list contents of supported archive files."""
    if not await aiofiles.os.path.isfile(file_path):
        return None

    with contextlib.suppress(Exception):
        if file_path.name.endswith((".tar.gz", ".tgz")):

            def _read_tar() -> list[str] | None:
                with contextlib.suppress(tarfile.ReadError, EOFError, ValueError, OSError):
                    with tarfile.open(file_path, mode="r:*") as tf:
                        # TODO: Skip metadata files
                        return sorted(tf.getnames())
                return None

            return await asyncio.to_thread(_read_tar)
        elif file_path.name.endswith(".zip"):

            def _read_zip() -> list[str] | None:
                with contextlib.suppress(zipfile.BadZipFile, EOFError, ValueError, OSError):
                    with zipfile.ZipFile(file_path, "r") as zf:
                        return sorted(zf.namelist())
                return None

            return await asyncio.to_thread(_read_zip)

    return None


def as_url(func: Callable, **kwargs: Any) -> str:
    """Return the URL for a function."""
    if isinstance(func, jinja2.runtime.Undefined):
        _LOGGER.exception("Undefined route in the calling template")
        raise RuntimeError("Undefined route in the calling template")
    try:
        annotations = func.__annotations__
    except AttributeError as e:
        _LOGGER.error(f"Cannot get annotations for {func} (type: {type(func)})")
        raise RuntimeError(f"Cannot get annotations for {func} (type: {type(func)})") from e
    return quart.url_for(annotations["endpoint"], **kwargs)


@contextlib.asynccontextmanager
async def async_temporary_directory(
    suffix: str | None = None, prefix: str | None = None, dir: str | pathlib.Path | None = None
) -> AsyncGenerator[pathlib.Path]:
    """Create an async temporary directory similar to tempfile.TemporaryDirectory."""
    temp_dir_path: str = await asyncio.to_thread(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)
    try:
        yield pathlib.Path(temp_dir_path)
    finally:
        await asyncio.to_thread(shutil.rmtree, temp_dir_path, ignore_errors=True)


def compute_sha3_256(file_data: bytes) -> str:
    """Compute SHA3-256 hash of file data."""
    return hashlib.sha3_256(file_data).hexdigest()


async def compute_sha512(file_path: pathlib.Path) -> str:
    """Compute SHA-512 hash of a file."""
    sha512 = hashlib.sha512()
    async with aiofiles.open(file_path, "rb") as f:
        while chunk := await f.read(4096):
            sha512.update(chunk)
    return sha512.hexdigest()


async def content_list(
    phase_subdir: pathlib.Path, project_name: str, version_name: str, revision_name: str | None = None
) -> AsyncGenerator[FileStat]:
    """List all the files in the given path."""
    base_path = phase_subdir / project_name / version_name
    if phase_subdir.name in {"release-candidate-draft", "release-preview"}:
        if revision_name is None:
            raise ValueError("A revision name is required for release candidate draft or preview content listing")
    if revision_name:
        base_path = base_path / revision_name
    async for path in paths_recursive(base_path):
        stat = await aiofiles.os.stat(base_path / path)
        yield FileStat(
            path=str(path),
            modified=int(stat.st_mtime),
            size=stat.st_size,
            permissions=stat.st_mode,
            is_file=bool(stat.st_mode & 0o0100000),
            is_dir=bool(stat.st_mode & 0o040000),
        )


async def create_hard_link_clone(source_dir: pathlib.Path, dest_dir: pathlib.Path) -> None:
    """Recursively create a clone of source_dir in dest_dir using hard links for files."""
    # TODO: We're currently using cp -al instead
    # Ensure source exists and is a directory
    if not await aiofiles.os.path.isdir(source_dir):
        raise ValueError(f"Source path is not a directory or does not exist: {source_dir}")

    # Create destination directory
    await aiofiles.os.makedirs(dest_dir, exist_ok=False)

    async def _clone_recursive(current_source: pathlib.Path, current_dest: pathlib.Path) -> None:
        for entry in await aiofiles.os.scandir(current_source):
            source_entry_path = current_source / entry.name
            dest_entry_path = current_dest / entry.name
            try:
                if entry.is_dir():
                    await aiofiles.os.makedirs(dest_entry_path, exist_ok=True)
                    await _clone_recursive(source_entry_path, dest_entry_path)
                elif entry.is_file():
                    await aiofiles.os.link(source_entry_path, dest_entry_path)
                # Ignore other types like symlinks for now
            except OSError as e:
                _LOGGER.error(f"Error cloning {source_entry_path} to {dest_entry_path}: {e}")
                raise

    await _clone_recursive(source_dir, dest_dir)
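
# Illustrative usage of create_hard_link_clone (the paths are hypothetical):
# clone a directory tree in which every file is a hard link to its source, so
# the copy is fast and shares storage with the original:
#
#     await create_hard_link_clone(
#         pathlib.Path("/srv/releases/1.2.3"),
#         pathlib.Path("/srv/scratch/1.2.3"),
#     )
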
async def file_sha3(path: str) -> str:
    """Compute SHA3-256 hash of a file."""
    sha3 = hashlib.sha3_256()
    async with aiofiles.open(path, "rb") as f:
        while chunk := await f.read(4096):
            sha3.update(chunk)
    return sha3.hexdigest()


def format_datetime(dt_obj: datetime.datetime | int) -> str:
    """Format a datetime object or Unix timestamp into a human readable datetime string."""
    # Integers are unix timestamps
    if isinstance(dt_obj, int):
        dt_obj = datetime.datetime.fromtimestamp(dt_obj, tz=datetime.UTC)
    # Ensure the datetime is timezone aware and in UTC
    if dt_obj.tzinfo is None:
        dt_obj = dt_obj.replace(tzinfo=datetime.UTC)
    else:
        # Convert to UTC if not already
        dt_obj = dt_obj.astimezone(datetime.UTC)
    return dt_obj.strftime("%Y-%m-%d %H:%M:%S")


def format_file_size(size_in_bytes: int) -> str:
    """Format a file size with appropriate units and comma-separated digits."""
    # Format the raw bytes with commas
    formatted_bytes = f"{size_in_bytes:,}"

    # Calculate the appropriate unit
    if size_in_bytes >= 1_000_000_000:
        size_in_gb = size_in_bytes // 1_000_000_000
        return f"{size_in_gb:,} GB ({formatted_bytes} bytes)"
    elif size_in_bytes >= 1_000_000:
        size_in_mb = size_in_bytes // 1_000_000
        return f"{size_in_mb:,} MB ({formatted_bytes} bytes)"
    elif size_in_bytes >= 1_000:
        size_in_kb = size_in_bytes // 1_000
        return f"{size_in_kb:,} KB ({formatted_bytes} bytes)"
    else:
        return f"{formatted_bytes} bytes"


def format_permissions(mode: int) -> str:
    """Format Unix file permissions in ls -l style."""
    # Compare the masked file type field rather than testing individual bits,
    # because the type bits overlap: a socket (0o140000) also carries the
    # regular file bit, and a block device (0o060000) carries the directory bit
    file_type = mode & 0o170000
    if file_type == 0o040000:  # Directory
        perms = "d"
    elif file_type == 0o100000:  # Regular file
        perms = "-"
    elif file_type == 0o020000:  # Character special
        perms = "c"
    elif file_type == 0o060000:  # Block special
        perms = "b"
    elif file_type == 0o010000:  # FIFO
        perms = "p"
    elif file_type == 0o140000:  # Socket
        perms = "s"
    elif file_type == 0o120000:  # Symbolic link
        perms = "l"
    else:
        perms = "?"

    # Owner permissions
    perms += "r" if mode & 0o400 else "-"
    perms += "w" if mode & 0o200 else "-"
    perms += "x" if mode & 0o100 else "-"
    # Group permissions
    perms += "r" if mode & 0o040 else "-"
    perms += "w" if mode & 0o020 else "-"
    perms += "x" if mode & 0o010 else "-"
    # Others permissions
    perms += "r" if mode & 0o004 else "-"
    perms += "w" if mode & 0o002 else "-"
    perms += "x" if mode & 0o001 else "-"

    return perms
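
# Illustrative values for format_permissions (computed by hand from the
# branches above):
#
#     >>> format_permissions(0o100644)
#     '-rw-r--r--'
#     >>> format_permissions(0o040755)
#     'drwxr-xr-x'
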
async def get_asf_id_or_die() -> str:
    web_session = await session.read()
    if web_session is None or web_session.uid is None:
        raise base.ASFQuartException("Not authenticated", errorcode=401)
    return web_session.uid


def get_finished_dir() -> pathlib.Path:
    return pathlib.Path(config.get().FINISHED_STORAGE_DIR)


async def get_release_stats(release: models.Release) -> tuple[int, int, str]:
    """Calculate file count, total byte size, and formatted size for a release."""
    base_dir = release_directory(release)
    count = 0
    total_bytes = 0
    try:
        async for rel_path in paths_recursive(base_dir):
            full_path = base_dir / rel_path
            if await aiofiles.os.path.isfile(full_path):
                try:
                    size = await aiofiles.os.path.getsize(full_path)
                    count += 1
                    total_bytes += size
                except OSError:
                    ...
    except FileNotFoundError:
        ...

    formatted_size = format_file_size(total_bytes)
    return count, total_bytes, formatted_size


def get_unfinished_dir() -> pathlib.Path:
    return pathlib.Path(config.get().UNFINISHED_STORAGE_DIR)


def is_user_viewing_as_admin(uid: str | None) -> bool:
    """Check whether a user is currently viewing the site with active admin privileges."""
    if not user.is_admin(uid):
        return False
    try:
        app = asfquart.APP
        if not hasattr(app, "app_id") or not isinstance(app.app_id, str):
            _LOGGER.error("Cannot get valid app_id to read session for admin view check")
            return True
        cookie_id = app.app_id
        session_dict = quart.session.get(cookie_id, {})
        is_downgraded = session_dict.get("downgrade_admin_to_user", False)
        return not is_downgraded
    except Exception:
        _LOGGER.exception(f"Error checking admin downgrade session status for {uid}")
        return True


async def number_of_release_files(release: models.Release) -> int:
    """Return the number of files in a release."""
    path_project = release.project.name
    path_version = release.version
    path_revision = release.revision or "force-error"
    match release.phase:
        case models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
            path = get_unfinished_dir() / path_project / path_version / path_revision
        case models.ReleasePhase.RELEASE_CANDIDATE:
            path = get_unfinished_dir() / path_project / path_version / path_revision
        case models.ReleasePhase.RELEASE_PREVIEW:
            path = get_unfinished_dir() / path_project / path_version / path_revision
        case models.ReleasePhase.RELEASE:
            path = get_finished_dir() / path_project / path_version
        case _:
            raise ValueError(f"Unknown release phase: {release.phase}")
    count = 0
    async for _ in paths_recursive(path):
        count += 1
    return count


async def paths_recursive(base_path: pathlib.Path) -> AsyncGenerator[pathlib.Path]:
    """Yield all file paths recursively within a base path, relative to the base path."""
    try:
        abs_base_path = await asyncio.to_thread(base_path.resolve)
        for entry in await aiofiles.os.scandir(abs_base_path):
            entry_path = pathlib.Path(entry.path)
            relative_path = entry_path.relative_to(abs_base_path)
            if entry.is_file():
                yield relative_path
            elif entry.is_dir():
                async for sub_path in paths_recursive(entry_path):
                    yield relative_path / sub_path
    except FileNotFoundError:
        return
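
# Illustrative usage of paths_recursive (the directory is hypothetical):
# collect every file path under a base directory, relative to that base:
#
#     async def list_tree() -> list[pathlib.Path]:
#         return [path async for path in paths_recursive(pathlib.Path("/srv/example"))]
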
def permitted_recipients(asf_uid: str) -> list[str]:
    test_list = "user-tests"
    return [
        # f"dev@{committee.name}.apache.org",
        # f"private@{committee.name}.apache.org",
        f"{test_list}@tooling.apache.org",
        f"{asf_uid}@apache.org",
    ]


async def read_file_for_viewer(full_path: pathlib.Path, max_size: int) -> tuple[str | None, bool, bool, str | None]:
    """Read file content for viewer."""
    content: str | None = None
    is_text = False
    is_truncated = False
    error_message: str | None = None

    try:
        if not await aiofiles.os.path.exists(full_path):
            return None, False, False, "File does not exist"
        if not await aiofiles.os.path.isfile(full_path):
            return None, False, False, "Path is not a file"

        file_size = await aiofiles.os.path.getsize(full_path)
        read_size = min(file_size, max_size)
        if file_size > max_size:
            is_truncated = True

        if file_size == 0:
            is_text = True
            content = "(Empty file)"
            raw_content = b""
        else:
            async with aiofiles.open(full_path, "rb") as f:
                raw_content = await f.read(read_size)

        if file_size > 0:
            try:
                if b"\x00" in raw_content:
                    raise UnicodeDecodeError("utf-8", b"", 0, 1, "Null byte found")
                content = raw_content.decode("utf-8")
                is_text = True
            except UnicodeDecodeError:
                is_text = False
                content = _generate_hexdump(raw_content)

    except Exception as e:
        error_message = f"An error occurred reading the file: {e!s}"

    return content, is_text, is_truncated, error_message


def release_directory(release: models.Release) -> pathlib.Path:
    """Return the absolute path to the directory containing the active files for a given release phase."""
    if release.revision is None:
        return release_directory_base(release)
    return release_directory_base(release) / release.revision


def release_directory_base(release: models.Release) -> pathlib.Path:
    """Determine the filesystem directory for a given release based on its phase."""
    phase = release.phase
    try:
        project_name, version_name = release.name.rsplit("-", 1)
    except ValueError:
        raise base.ASFQuartException(f"Invalid release name format '{release.name}'", 500)

    base_dir: pathlib.Path | None = None
    match phase:
        case models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
            base_dir = get_unfinished_dir()
        case models.ReleasePhase.RELEASE_CANDIDATE:
            base_dir = get_unfinished_dir()
        case models.ReleasePhase.RELEASE_PREVIEW:
            base_dir = get_unfinished_dir()
        case models.ReleasePhase.RELEASE:
            base_dir = get_finished_dir()
        # NOTE: Do NOT add "case _" here
    return base_dir / project_name / version_name


def unwrap(value: T | None, error_message: str = "unexpected None when unwrapping value") -> T:
    """
    Unwrap the given value, or raise a ValueError if it is None.

    :param value: the optional value to unwrap
    :param error_message: the error message when failing to unwrap
    :return: the value, never None
    """
    if value is None:
        raise ValueError(error_message)
    else:
        return value


def unwrap_type(value: T | None, t: type[T], error_message: str = "unexpected None when unwrapping value") -> T:
    """
    Unwrap the given value, or raise a ValueError if it is None or not of the expected type.

    :param value: the optional value to unwrap
    :param t: the expected type of the value
    :param error_message: the error message when failing to unwrap
    """
    if value is None:
        raise ValueError(error_message)
    if not isinstance(value, t):
        raise ValueError(f"Expected {t}, got {type(value)}")
    return value


async def update_atomic_symlink(link_path: pathlib.Path, target_path: pathlib.Path | str) -> None:
    """Atomically update or create a symbolic link at link_path pointing to target_path."""
    target_str = str(target_path)
    # Generate a temporary path name for the new link
    link_dir = link_path.parent
    temp_link_path = link_dir / f".{link_path.name}.{uuid.uuid4()}.tmp"

    try:
        await aiofiles.os.symlink(target_str, temp_link_path)
        # Atomically rename the temporary link to the final link path
        # This overwrites link_path if it exists
        await aiofiles.os.rename(temp_link_path, link_path)
        _LOGGER.info(f"Atomically updated symlink {link_path} -> {target_str}")
    except Exception as e:
        # Don't bother with _LOGGER.exception here
        _LOGGER.error(f"Failed to update atomic symlink {link_path} -> {target_str}: {e}")
        # Clean up the temporary link if the rename failed
        try:
            await aiofiles.os.remove(temp_link_path)
        except FileNotFoundError:
            # TODO: Use with contextlib.suppress(FileNotFoundError) for these sorts of blocks?
            pass
        raise
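
# Illustrative usage of update_atomic_symlink (the paths are hypothetical):
# because rename() replaces an existing path atomically on POSIX filesystems,
# readers always see either the old target or the new one, never a missing
# link:
#
#     await update_atomic_symlink(
#         pathlib.Path("/srv/releases/current"),
#         pathlib.Path("/srv/releases/1.2.3"),
#     )
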
def user_releases(asf_uid: str, releases: Sequence[models.Release]) -> list[models.Release]:
    """Return a list of releases for which the user is a committee member or committer."""
    # TODO: This should probably be a session method instead
    user_releases = []
    for release in releases:
        if release.committee is None:
            continue
        if (asf_uid in release.committee.committee_members) or (asf_uid in release.committee.committers):
            user_releases.append(release)
    return user_releases


def validate_as_type(value: Any, t: type[T]) -> T:
    """Validate the given value as the given type."""
    if not isinstance(value, t):
        raise ValueError(f"Expected {t}, got {type(value)}")
    return value


def validate_vote_duration(form: wtforms.Form, field: wtforms.IntegerField) -> None:
    """Checks if the value is 0 or between 72 and 144."""
    if not ((field.data == 0) or (72 <= field.data <= 144)):
        raise wtforms.validators.ValidationError("Minimum voting period must be 0 hours, or between 72 and 144 hours")


def version_name_error(version_name: str) -> str | None:
    """Check if the given version name is valid."""
    if version_name == "":
        return "Must not be empty"
    if version_name.lower() == "version":
        return "Must not be 'version'"
    if not re.match(r"^[a-zA-Z0-9]", version_name):
        return "Must start with a letter or number"
    if not re.search(r"[a-zA-Z0-9]$", version_name):
        return "Must end with a letter or number"
    if re.search(r"[+.-]{2,}", version_name):
        return "Must not contain multiple consecutive plus, full stop, or hyphen"
    if not re.match(r"^[a-zA-Z0-9+.-]+$", version_name):
        return "Must contain only letters, numbers, plus, full stop, or hyphen"
    return None
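
# Illustrative values for version_name_error (computed by hand from the rules
# above):
#
#     >>> version_name_error("1.0.0")  # valid, so returns None
#     >>> version_name_error("1..0")
#     'Must not contain multiple consecutive plus, full stop, or hyphen'
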
def _generate_hexdump(data: bytes) -> str:
    """Generate a formatted hexdump string from bytes."""
    hex_lines = []
    for i in range(0, len(data), 16):
        chunk = data[i : i + 16]
        hex_part = binascii.hexlify(chunk).decode("ascii")
        hex_part = hex_part.ljust(32)
        hex_part_spaced = " ".join(hex_part[j : j + 2] for j in range(0, len(hex_part), 2))
        ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
        line_num = f"{i:08x}"
        hex_lines.append(f"{line_num} {hex_part_spaced} |{ascii_part}|")
    return "\n".join(hex_lines)


def _get_dict_to_list_inner_type_adapter(source_type: Any, key: str) -> pydantic.TypeAdapter[dict[Any, Any]]:
    root_adapter = pydantic.TypeAdapter(source_type)
    schema = root_adapter.core_schema

    # support further nesting of model classes
    if schema["type"] == "definitions":
        schema = schema["schema"]

    assert schema["type"] == "list"
    assert (item_schema := schema["items_schema"])
    assert item_schema["type"] == "model"
    assert (cls := item_schema["cls"])  # noqa: RUF018

    fields = cls.model_fields
    assert (key_field := fields.get(key))  # noqa: RUF018
    assert (other_fields := {k: v for k, v in fields.items() if k != key})  # noqa: RUF018

    model_name = f"{cls.__name__}Inner"
    # Create proper field definitions for create_model
    inner_model = pydantic.create_model(model_name, **{k: (v.annotation, v) for k, v in other_fields.items()})  # type: ignore
    return pydantic.TypeAdapter(dict[Annotated[str, key_field], inner_model])  # type: ignore


def _get_dict_to_list_validator(inner_adapter: pydantic.TypeAdapter[dict[Any, Any]], key: str) -> Any:
    def validator(val: Any) -> Any:
        import pydantic.fields as fields

        if isinstance(val, dict):
            validated = inner_adapter.validate_python(val)

            # need to get the alias of the field in the nested model
            # as this will be fed into the actual model class
            def get_alias(field_name: str, field_infos: Mapping[str, fields.FieldInfo]) -> Any:
                field = field_infos[field_name]
                return field.alias if field.alias else field_name

            return [
                {key: k, **{get_alias(f, v.model_fields): getattr(v, f) for f in v.model_fields}}
                for k, v in validated.items()
            ]
        return val

    return validator
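
# Illustrative round trip for DictToList and the two helpers above (the models
# are hypothetical, not defined in this module): a field annotated with
# DictToList accepts a mapping keyed by the named field and validates it into
# a list of models, with the mapping key folded back in as that field:
#
#     class Item(pydantic.BaseModel):
#         name: str
#         value: int
#
#     class Holder(pydantic.BaseModel):
#         items: Annotated[list[Item], DictToList(key="name")]
#
#     Holder.model_validate({"items": {"a": {"value": 1}}})
#     # -> Holder(items=[Item(name="a", value=1)])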