# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio
import binascii
import contextlib
import dataclasses
import datetime
import hashlib
import logging
import pathlib
import re
import shutil
import tarfile
import tempfile
import uuid
import zipfile
from collections.abc import AsyncGenerator, Callable, Generator, ItemsView, Mapping, Sequence
from typing import Annotated, Any, Final, TypeVar

import aiofiles.os
import asfquart
import asfquart.base as base
import asfquart.session as session
import jinja2
import pydantic
import pydantic_core
import quart
import quart_wtf
import quart_wtf.typing
import wtforms

# NOTE: The atr.db module imports this module
# Therefore, this module must not import atr.db
import atr.config as config
import atr.db.models as models
import atr.user as user

F = TypeVar("F", bound="QuartFormTyped")
T = TypeVar("T")
VT = TypeVar("VT")

_LOGGER: Final = logging.getLogger(__name__)


class DictRootModel(pydantic.RootModel[dict[str, VT]]):
    def __iter__(self) -> Generator[tuple[str, VT]]:
        yield from self.root.items()

    def items(self) -> ItemsView[str, VT]:
        return self.root.items()

    def get(self, key: str) -> VT | None:
        return self.root.get(key)

    def __len__(self) -> int:
        return len(self.root)

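# Usage sketch (hypothetical value type, for illustration only):
#
#     class Labels(DictRootModel[str]): ...
#
#     labels = Labels.model_validate({"a": "x", "b": "y"})
#     labels.get("a")  # "x"
#     len(labels)      # 2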

# from https://github.com/pydantic/pydantic/discussions/8755#discussioncomment-8417979
@dataclasses.dataclass
class DictToList:
    key: str

    def __get_pydantic_core_schema__(
        self,
        source_type: Any,
        handler: pydantic.GetCoreSchemaHandler,
    ) -> pydantic_core.CoreSchema:
        adapter = _get_dict_to_list_inner_type_adapter(source_type, self.key)

        return pydantic_core.core_schema.no_info_before_validator_function(
            _get_dict_to_list_validator(adapter, self.key),
            handler(source_type),
        )

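# Usage sketch (hypothetical models, for illustration only): DictToList lets a
# field declared as a list of models also accept a mapping keyed by one field.
#
#     class Artifact(pydantic.BaseModel):
#         filename: str
#         size: int
#
#     class Upload(pydantic.BaseModel):
#         artifacts: Annotated[list[Artifact], DictToList(key="filename")]
#
# With this, {"artifacts": {"a.tar.gz": {"size": 1}}} validates the same as
# {"artifacts": [{"filename": "a.tar.gz", "size": 1}]}.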

@dataclasses.dataclass
class FileStat:
    path: str
    modified: int
    size: int
    permissions: int
    is_file: bool
    is_dir: bool


class QuartFormTyped(quart_wtf.QuartForm):
    """Quart form with type annotations."""

    @classmethod
    async def create_form(
        cls: type[F],
        formdata: object | quart_wtf.typing.FormData = quart_wtf.form._Auto,
        obj: Any | None = None,
        prefix: str = "",
        data: dict | None = None,
        meta: dict | None = None,
        **kwargs: dict[str, Any],
    ) -> F:
        """Create a form instance with typing."""
        form = await super().create_form(formdata, obj, prefix, data, meta, **kwargs)
        if not isinstance(form, cls):
            raise TypeError(f"Form is not of type {cls.__name__}")
        return form

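# Usage sketch (hypothetical form, for illustration only):
#
#     class ExampleForm(QuartFormTyped):
#         name = wtforms.StringField("Name")
#
#     form = await ExampleForm.create_form()  # typed as ExampleForm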

async def archive_listing(file_path: pathlib.Path) -> list[str] | None:
    """Attempt to list contents of supported archive files."""
    if not await aiofiles.os.path.isfile(file_path):
        return None

    with contextlib.suppress(Exception):
        if file_path.name.endswith((".tar.gz", ".tgz")):

            def _read_tar() -> list[str] | None:
                with contextlib.suppress(tarfile.ReadError, EOFError, ValueError, OSError):
                    with tarfile.open(file_path, mode="r:*") as tf:
                        # TODO: Skip metadata files
                        return sorted(tf.getnames())
                return None

            return await asyncio.to_thread(_read_tar)

        elif file_path.name.endswith(".zip"):

            def _read_zip() -> list[str] | None:
                with contextlib.suppress(zipfile.BadZipFile, EOFError, ValueError, OSError):
                    with zipfile.ZipFile(file_path, "r") as zf:
                        return sorted(zf.namelist())
                return None

            return await asyncio.to_thread(_read_zip)

    return None

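# Usage sketch (hypothetical path): archive_listing returns None for
# unsupported or unreadable archives, so callers can branch on the result.
#
#     names = await archive_listing(pathlib.Path("example.tar.gz"))
#     if names is not None:
#         ...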

def as_url(func: Callable, **kwargs: Any) -> str:
    """Return the URL for a function."""
    if isinstance(func, jinja2.runtime.Undefined):
        _LOGGER.error("Undefined route in the calling template")
        raise RuntimeError("Undefined route in the calling template")
    try:
        annotations = func.__annotations__
    except AttributeError as e:
        _LOGGER.error(f"Cannot get annotations for {func} (type: {type(func)})")
        raise RuntimeError(f"Cannot get annotations for {func} (type: {type(func)})") from e
    return quart.url_for(annotations["endpoint"], **kwargs)


@contextlib.asynccontextmanager
async def async_temporary_directory(
    suffix: str | None = None, prefix: str | None = None, dir: str | pathlib.Path | None = None
) -> AsyncGenerator[pathlib.Path]:
    """Create an async temporary directory similar to tempfile.TemporaryDirectory."""
    temp_dir_path: str = await asyncio.to_thread(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)
    try:
        yield pathlib.Path(temp_dir_path)
    finally:
        await asyncio.to_thread(shutil.rmtree, temp_dir_path, ignore_errors=True)

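# Usage sketch: the directory and all of its contents are removed when the
# block exits, even on error.
#
#     async with async_temporary_directory(prefix="atr-") as tmp_dir:
#         await aiofiles.os.makedirs(tmp_dir / "work")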

def compute_sha3_256(file_data: bytes) -> str:
    """Compute SHA3-256 hash of file data."""
    return hashlib.sha3_256(file_data).hexdigest()


async def compute_sha512(file_path: pathlib.Path) -> str:
    """Compute SHA-512 hash of a file."""
    sha512 = hashlib.sha512()
    async with aiofiles.open(file_path, "rb") as f:
        while chunk := await f.read(4096):
            sha512.update(chunk)
    return sha512.hexdigest()


async def content_list(
    phase_subdir: pathlib.Path, project_name: str, version_name: str, revision_name: str | None = None
) -> AsyncGenerator[FileStat]:
    """List all the files in the given path."""
    base_path = phase_subdir / project_name / version_name
    if phase_subdir.name in {"release-candidate-draft", "release-preview"}:
        if revision_name is None:
            raise ValueError("A revision name is required for release candidate draft or preview content listing")
    if revision_name:
        base_path = base_path / revision_name
    async for path in paths_recursive(base_path):
        stat = await aiofiles.os.stat(base_path / path)
        yield FileStat(
            path=str(path),
            modified=int(stat.st_mtime),
            size=stat.st_size,
            permissions=stat.st_mode,
            # The S_IFMT bits encode one exact type value, not independent flags
            is_file=(stat.st_mode & 0o170000) == 0o100000,
            is_dir=(stat.st_mode & 0o170000) == 0o040000,
        )


async def create_hard_link_clone(source_dir: pathlib.Path, dest_dir: pathlib.Path) -> None:
    """Recursively create a clone of source_dir in dest_dir using hard links for files."""
    # TODO: We're currently using cp -al instead
    # Ensure source exists and is a directory
    if not await aiofiles.os.path.isdir(source_dir):
        raise ValueError(f"Source path is not a directory or does not exist: {source_dir}")

    # Create destination directory
    await aiofiles.os.makedirs(dest_dir, exist_ok=False)

    async def _clone_recursive(current_source: pathlib.Path, current_dest: pathlib.Path) -> None:
        for entry in await aiofiles.os.scandir(current_source):
            source_entry_path = current_source / entry.name
            dest_entry_path = current_dest / entry.name

            try:
                if entry.is_dir():
                    await aiofiles.os.makedirs(dest_entry_path, exist_ok=True)
                    await _clone_recursive(source_entry_path, dest_entry_path)
                elif entry.is_file():
                    await aiofiles.os.link(source_entry_path, dest_entry_path)
                # Ignore other types like symlinks for now
            except OSError as e:
                _LOGGER.error(f"Error cloning {source_entry_path} to {dest_entry_path}: {e}")
                raise

    await _clone_recursive(source_dir, dest_dir)


async def file_sha3(path: str) -> str:
    """Compute SHA3-256 hash of a file."""
    sha3 = hashlib.sha3_256()
    async with aiofiles.open(path, "rb") as f:
        while chunk := await f.read(4096):
            sha3.update(chunk)
    return sha3.hexdigest()


def format_datetime(dt_obj: datetime.datetime | int) -> str:
    """Format a datetime object or Unix timestamp into a human readable datetime string."""
    # Integers are Unix timestamps
    if isinstance(dt_obj, int):
        dt_obj = datetime.datetime.fromtimestamp(dt_obj, tz=datetime.UTC)

    # Treat naive datetimes as UTC, and convert aware ones to UTC
    if dt_obj.tzinfo is None:
        dt_obj = dt_obj.replace(tzinfo=datetime.UTC)
    else:
        dt_obj = dt_obj.astimezone(datetime.UTC)

    return dt_obj.strftime("%Y-%m-%d %H:%M:%S")

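# For example, format_datetime(1700000000) and the naive
# datetime.datetime(2023, 11, 14, 22, 13, 20) both give "2023-11-14 22:13:20".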

def format_file_size(size_in_bytes: int) -> str:
    """Format a file size with appropriate units and comma-separated digits."""
    # Format the raw bytes with commas
    formatted_bytes = f"{size_in_bytes:,}"

    # Calculate the appropriate unit
    if size_in_bytes >= 1_000_000_000:
        size_in_gb = size_in_bytes // 1_000_000_000
        return f"{size_in_gb:,} GB ({formatted_bytes} bytes)"
    elif size_in_bytes >= 1_000_000:
        size_in_mb = size_in_bytes // 1_000_000
        return f"{size_in_mb:,} MB ({formatted_bytes} bytes)"
    elif size_in_bytes >= 1_000:
        size_in_kb = size_in_bytes // 1_000
        return f"{size_in_kb:,} KB ({formatted_bytes} bytes)"
    else:
        return f"{formatted_bytes} bytes"

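# For example, format_file_size(1_234_567) gives "1 MB (1,234,567 bytes)",
# since units are computed with integer division rather than rounding.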

def format_permissions(mode: int) -> str:
    """Format Unix file permissions in ls -l style."""
    # The S_IFMT bits encode the file type as one exact value, not as flags,
    # so compare the masked value instead of testing individual bits
    # (bit tests would misreport block specials and sockets as directories)
    file_type = mode & 0o170000
    if file_type == 0o040000:
        # Directory
        perms = "d"
    elif file_type == 0o100000:
        # Regular file
        perms = "-"
    elif file_type == 0o120000:
        # Symbolic link
        perms = "l"
    elif file_type == 0o020000:
        # Character special
        perms = "c"
    elif file_type == 0o060000:
        # Block special
        perms = "b"
    elif file_type == 0o010000:
        # FIFO
        perms = "p"
    elif file_type == 0o140000:
        # Socket
        perms = "s"
    else:
        perms = "?"

    # Owner permissions
    perms += "r" if mode & 0o400 else "-"
    perms += "w" if mode & 0o200 else "-"
    perms += "x" if mode & 0o100 else "-"

    # Group permissions
    perms += "r" if mode & 0o040 else "-"
    perms += "w" if mode & 0o020 else "-"
    perms += "x" if mode & 0o010 else "-"

    # Others permissions
    perms += "r" if mode & 0o004 else "-"
    perms += "w" if mode & 0o002 else "-"
    perms += "x" if mode & 0o001 else "-"

    return perms

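# For example, format_permissions(0o100644) gives "-rw-r--r--" and
# format_permissions(0o040755) gives "drwxr-xr-x".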

async def get_asf_id_or_die() -> str:
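    """Return the ASF UID for the current session, raising a 401 error if not authenticated."""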
    web_session = await session.read()
    if web_session is None or web_session.uid is None:
        raise base.ASFQuartException("Not authenticated", errorcode=401)
    return web_session.uid


def get_finished_dir() -> pathlib.Path:
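    """Return the storage directory for finished releases."""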
    return pathlib.Path(config.get().FINISHED_STORAGE_DIR)


async def get_release_stats(release: models.Release) -> tuple[int, int, str]:
    """Calculate file count, total byte size, and formatted size for a release."""
    base_dir = release_directory(release)
    count = 0
    total_bytes = 0
    try:
        async for rel_path in paths_recursive(base_dir):
            full_path = base_dir / rel_path
            if await aiofiles.os.path.isfile(full_path):
                try:
                    size = await aiofiles.os.path.getsize(full_path)
                    count += 1
                    total_bytes += size
                except OSError:
                    ...
    except FileNotFoundError:
        ...

    formatted_size = format_file_size(total_bytes)
    return count, total_bytes, formatted_size


def get_unfinished_dir() -> pathlib.Path:
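    """Return the storage directory for unfinished releases."""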
    return pathlib.Path(config.get().UNFINISHED_STORAGE_DIR)


def is_user_viewing_as_admin(uid: str | None) -> bool:
    """Check whether a user is currently viewing the site with active admin privileges."""
    if not user.is_admin(uid):
        return False

    try:
        app = asfquart.APP
        if not hasattr(app, "app_id") or not isinstance(app.app_id, str):
            _LOGGER.error("Cannot get valid app_id to read session for admin view check")
            return True

        cookie_id = app.app_id
        session_dict = quart.session.get(cookie_id, {})
        is_downgraded = session_dict.get("downgrade_admin_to_user", False)
        return not is_downgraded
    except Exception:
        _LOGGER.exception(f"Error checking admin downgrade session status for {uid}")
        return True


async def number_of_release_files(release: models.Release) -> int:
    """Return the number of files in a release."""
    path_project = release.project.name
    path_version = release.version
    path_revision = release.revision or "force-error"
    match release.phase:
        case (
            models.ReleasePhase.RELEASE_CANDIDATE_DRAFT
            | models.ReleasePhase.RELEASE_CANDIDATE
            | models.ReleasePhase.RELEASE_PREVIEW
        ):
            path = get_unfinished_dir() / path_project / path_version / path_revision
        case models.ReleasePhase.RELEASE:
            path = get_finished_dir() / path_project / path_version
        case _:
            raise ValueError(f"Unknown release phase: {release.phase}")
    count = 0
    async for _ in paths_recursive(path):
        count += 1
    return count


async def paths_recursive(base_path: pathlib.Path) -> AsyncGenerator[pathlib.Path]:
    """Yield all file paths recursively within a base path, relative to the base path."""
    try:
        abs_base_path = await asyncio.to_thread(base_path.resolve)
        for entry in await aiofiles.os.scandir(abs_base_path):
            entry_path = pathlib.Path(entry.path)
            relative_path = entry_path.relative_to(abs_base_path)
            if entry.is_file():
                yield relative_path
            elif entry.is_dir():
                async for sub_path in paths_recursive(entry_path):
                    yield relative_path / sub_path
    except FileNotFoundError:
        return


def permitted_recipients(asf_uid: str) -> list[str]:
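    """Return the email addresses to which the given user may send."""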
    test_list = "user-tests"
    return [
        # f"dev@{committee.name}.apache.org",
        # f"private@{committee.name}.apache.org",
        f"{test_list}@tooling.apache.org",
        f"{asf_uid}@apache.org",
    ]


async def read_file_for_viewer(full_path: pathlib.Path, max_size: int) -> tuple[str | None, bool, bool, str | None]:
    """Read file content for viewer."""
    content: str | None = None
    is_text = False
    is_truncated = False
    error_message: str | None = None

    try:
        if not await aiofiles.os.path.exists(full_path):
            return None, False, False, "File does not exist"
        if not await aiofiles.os.path.isfile(full_path):
            return None, False, False, "Path is not a file"

        file_size = await aiofiles.os.path.getsize(full_path)
        read_size = min(file_size, max_size)

        if file_size > max_size:
            is_truncated = True

        if file_size == 0:
            is_text = True
            content = "(Empty file)"
            raw_content = b""
        else:
            async with aiofiles.open(full_path, "rb") as f:
                raw_content = await f.read(read_size)

        if file_size > 0:
            try:
                if b"\x00" in raw_content:
                    raise UnicodeDecodeError("utf-8", b"", 0, 1, "Null byte found")
                content = raw_content.decode("utf-8")
                is_text = True
            except UnicodeDecodeError:
                is_text = False
                content = _generate_hexdump(raw_content)

    except Exception as e:
        error_message = f"An error occurred reading the file: {e!s}"

    return content, is_text, is_truncated, error_message


def release_directory(release: models.Release) -> pathlib.Path:
    """Return the absolute path to the directory containing the active files for a given release phase."""
    if release.revision is None:
        return release_directory_base(release)
    return release_directory_base(release) / release.revision


def release_directory_base(release: models.Release) -> pathlib.Path:
    """Determine the filesystem directory for a given release based on its phase."""
    phase = release.phase
    try:
        project_name, version_name = release.name.rsplit("-", 1)
    except ValueError:
        raise base.ASFQuartException(f"Invalid release name format '{release.name}'", 500)

    base_dir: pathlib.Path | None = None
    match phase:
        case (
            models.ReleasePhase.RELEASE_CANDIDATE_DRAFT
            | models.ReleasePhase.RELEASE_CANDIDATE
            | models.ReleasePhase.RELEASE_PREVIEW
        ):
            base_dir = get_unfinished_dir()
        case models.ReleasePhase.RELEASE:
            base_dir = get_finished_dir()
        # NOTE: Do NOT add "case _" here, so that type checkers can flag any unhandled phase

    return base_dir / project_name / version_name


def unwrap(value: T | None, error_message: str = "unexpected None when unwrapping value") -> T:
    """
    Will unwrap the given value or raise a ValueError if it is None

    :param value: the optional value to unwrap
    :param error_message: the error message when failing to unwrap
    :return: the value or a ValueError if it is None
    """
    if value is None:
        raise ValueError(error_message)
    else:
        return value

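# For example, unwrap(release.committee, "release has no committee") returns
# the committee when set and raises ValueError("release has no committee")
# when it is None.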

def unwrap_type(value: T | None, t: type[T], error_message: str = "unexpected None when unwrapping value") -> T:
    """
    Will unwrap the given value or raise a TypeError if it is not of the expected type

    :param value: the optional value to unwrap
    :param t: the expected type of the value
    :param error_message: the error message when failing to unwrap
    """
    if value is None:
        raise ValueError(error_message)
    if not isinstance(value, t):
        raise ValueError(f"Expected {t}, got {type(value)}")
    return value


async def update_atomic_symlink(link_path: pathlib.Path, target_path: pathlib.Path | str) -> None:
    """Atomically update or create a symbolic link at link_path pointing to target_path."""
    target_str = str(target_path)

    # Generate a temporary path name for the new link
    link_dir = link_path.parent
    temp_link_path = link_dir / f".{link_path.name}.{uuid.uuid4()}.tmp"

    try:
        await aiofiles.os.symlink(target_str, temp_link_path)
        # Atomically rename the temporary link to the final link path
        # This overwrites link_path if it exists
        await aiofiles.os.rename(temp_link_path, link_path)
        _LOGGER.info(f"Atomically updated symlink {link_path} -> {target_str}")
    except Exception as e:
        # Don't bother with _LOGGER.exception here
        _LOGGER.error(f"Failed to update atomic symlink {link_path} -> {target_str}: {e}")
        # Clean up the temporary link if the rename failed
        with contextlib.suppress(FileNotFoundError):
            await aiofiles.os.remove(temp_link_path)
        raise

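# Usage sketch (hypothetical paths): repoint a stable "latest" link at a new
# revision directory without any window in which the link is missing.
#
#     await update_atomic_symlink(base_path / "latest", base_path / "r2")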

def user_releases(asf_uid: str, releases: Sequence[models.Release]) -> list[models.Release]:
    """Return a list of releases for which the user is a committee member or committer."""
    # TODO: This should probably be a session method instead
    matching = []
    for release in releases:
        if release.committee is None:
            continue
        if (asf_uid in release.committee.committee_members) or (asf_uid in release.committee.committers):
            matching.append(release)
    return matching


def validate_as_type(value: Any, t: type[T]) -> T:
    """Validate the given value as the given type."""
    if not isinstance(value, t):
        raise ValueError(f"Expected {t}, got {type(value)}")
    return value


def validate_vote_duration(form: wtforms.Form, field: wtforms.IntegerField) -> None:
    """Checks if the value is 0 or between 72 and 144."""
    if not ((field.data == 0) or (72 <= field.data <= 144)):
        raise wtforms.validators.ValidationError("Minimum voting period must be 0 hours, or between 72 and 144 hours")


def version_name_error(version_name: str) -> str | None:
    """Check if the given version name is valid."""
    if version_name == "":
        return "Must not be empty"
    if version_name.lower() == "version":
        return "Must not be 'version'"
    if not re.match(r"^[a-zA-Z0-9]", version_name):
        return "Must start with a letter or number"
    if not re.search(r"[a-zA-Z0-9]$", version_name):
        return "Must end with a letter or number"
    if re.search(r"[+.-]{2,}", version_name):
        return "Must not contain multiple consecutive plus, full stop, or hyphen"
    if not re.match(r"^[a-zA-Z0-9+.-]+$", version_name):
        return "Must contain only letters, numbers, plus, full stop, or hyphen"
    return None

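# For example, version_name_error("1.0.0") gives None, version_name_error("1..0")
# gives the consecutive separator error, and version_name_error("-rc1") fails
# the leading character check.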

def _generate_hexdump(data: bytes) -> str:
    """Generate a formatted hexdump string from bytes."""
    hex_lines = []
    for i in range(0, len(data), 16):
        chunk = data[i : i + 16]
        hex_part = binascii.hexlify(chunk).decode("ascii")
        hex_part = hex_part.ljust(32)
        hex_part_spaced = " ".join(hex_part[j : j + 2] for j in range(0, len(hex_part), 2))
        ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
        line_num = f"{i:08x}"
        hex_lines.append(f"{line_num}  {hex_part_spaced}  |{ascii_part}|")
    return "\n".join(hex_lines)


def _get_dict_to_list_inner_type_adapter(source_type: Any, key: str) -> pydantic.TypeAdapter[dict[Any, Any]]:
    root_adapter = pydantic.TypeAdapter(source_type)
    schema = root_adapter.core_schema

    # Support further nesting of model classes
    if schema["type"] == "definitions":
        schema = schema["schema"]

    assert schema["type"] == "list"
    assert (item_schema := schema["items_schema"])
    assert item_schema["type"] == "model"
    assert (cls := item_schema["cls"])  # noqa: RUF018

    fields = cls.model_fields

    assert (key_field := fields.get(key))  # noqa: RUF018
    assert (other_fields := {k: v for k, v in fields.items() if k != key})  # noqa: RUF018

    model_name = f"{cls.__name__}Inner"

    # Create proper field definitions for create_model
    inner_model = pydantic.create_model(model_name, **{k: (v.annotation, v) for k, v in other_fields.items()})  # type: ignore
    return pydantic.TypeAdapter(dict[Annotated[str, key_field], inner_model])  # type: ignore


def _get_dict_to_list_validator(inner_adapter: pydantic.TypeAdapter[dict[Any, Any]], key: str) -> Any:
    def validator(val: Any) -> Any:
        import pydantic.fields as fields

        if isinstance(val, dict):
            validated = inner_adapter.validate_python(val)

            # Need to get the alias of the field in the nested model,
            # as this will be fed into the actual model class
            def get_alias(field_name: str, field_infos: Mapping[str, fields.FieldInfo]) -> Any:
                field = field_infos[field_name]
                return field.alias if field.alias else field_name

            return [
                {key: k, **{get_alias(f, v.model_fields): getattr(v, f) for f in v.model_fields}}
                for k, v in validated.items()
            ]

        return val

    return validator
