#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Module to update db migration information in Airflow
"""

from __future__ import annotations

import os
import re
import textwrap
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING

from alembic.script import ScriptDirectory
from rich.console import Console
from tabulate import tabulate

from airflow import __version__ as airflow_version
from airflow.providers.fab import __version__ as fab_version
from airflow.utils.db import _get_alembic_config

if TYPE_CHECKING:
    from alembic.script import Script

console = Console(width=400, color_system="standard")

airflow_version = re.match(r"(\d+\.\d+\.\d+).*", airflow_version).group(1)  # type: ignore
fab_version = re.match(r"(\d+\.\d+\.\d+).*", fab_version).group(1)  # type: ignore
project_root = Path(__file__).parents[2].resolve()


def replace_text_between(file: Path, start: str, end: str, replacement_text: str):
    original_text = file.read_text()
    leading_text = original_text.split(start)[0]
    trailing_text = original_text.split(end)[1]
    file.write_text(leading_text + start + replacement_text + end + trailing_text)


def wrap_backticks(val):
    def _wrap_backticks(x):
        return f"``{x}``"

    return ",\n".join(map(_wrap_backticks, val)) if isinstance(val, (tuple, list)) else _wrap_backticks(val)


def update_doc(file, data, app):
    replace_text_between(
        file=file,
        start=" .. Beginning of auto-generated table\n",
        end=" .. End of auto-generated table\n",
        replacement_text="\n"
        + tabulate(
            headers={
                "revision": "Revision ID",
                "down_revision": "Revises ID",
                "version": f"{app.title()} Version",
                "description": "Description",
            },
            tabular_data=data,
            tablefmt="grid",
            stralign="left",
            disable_numparse=True,
        )
        + "\n\n",
    )


def has_version(content, app):
    if app == "airflow":
        return re.search(r"^airflow_version\s*=.*", content, flags=re.MULTILINE) is not None
    return re.search(r"^fab_version\s*=.*", content, flags=re.MULTILINE) is not None


def insert_version(old_content, file, app):
    if app == "airflow":
        new_content = re.sub(
            r"(^depends_on.*)",
            lambda x: f'{x.group(1)}\nairflow_version = "{airflow_version}"',
            old_content,
            flags=re.MULTILINE,
        )
    else:
        new_content = re.sub(
            r"(^depends_on.*)",
            lambda x: f'{x.group(1)}\nfab_version = "{fab_version}"',
            old_content,
            flags=re.MULTILINE,
        )
    file.write_text(new_content)


def revision_suffix(rev: Script):
    if rev.is_head:
        return " (head)"
    if rev.is_base:
        return " (base)"
    if rev.is_merge_point:
        return " (merge_point)"
    if rev.is_branch_point:
        return " (branch_point)"
    return ""


def ensure_version(revisions: Iterable[Script], app):
    for rev in revisions:
        if TYPE_CHECKING:  # For mypy
            assert rev.module.__file__ is not None
        file = Path(rev.module.__file__)
        content = file.read_text()
        if not has_version(content, app=app):
            insert_version(content, file, app=app)


def get_revisions(app="airflow") -> Iterable[Script]:
    if app == "airflow":
        config = _get_alembic_config()
        script = ScriptDirectory.from_config(config)
        yield from script.walk_revisions()
    else:
        from airflow.providers.fab.auth_manager.models.db import FABDBManager

        script = FABDBManager(session="").get_script_object()
        yield from script.walk_revisions()


def update_docs(revisions: Iterable[Script], app="airflow"):
    doc_data = []
    for rev in revisions:
        app_revision = rev.module.airflow_version if app == "airflow" else rev.module.fab_version
        doc_data.append(
            dict(
                revision=wrap_backticks(rev.revision) + revision_suffix(rev),
                down_revision=wrap_backticks(rev.down_revision),
                version=wrap_backticks(app_revision),  # type: ignore
                description="\n".join(textwrap.wrap(rev.doc, width=60)),
            )
        )
    if app == "fab":
        filepath = project_root / "providers" / "fab" / "docs" / "migrations-ref.rst"
    else:
        filepath = project_root / "airflow-core" / "docs" / "migrations-ref.rst"

    update_doc(
        file=filepath,
        data=doc_data,
        app=app,
    )


def ensure_mod_prefix(mod_name, idx, version):
    parts = [f"{idx + 1:04}", *version]
    match = re.match(r"([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)_(.+)", mod_name)
    if match:
        # previously standardized file, rebuild the name
        parts.append(match.group(5))
    else:
        # new migration file, standard format
        match = re.match(r"([a-z0-9]+)_(.+)", mod_name)
        if match:
            parts.append(match.group(2))
    return "_".join(parts)


def ensure_filenames_are_sorted(revisions, app):
    renames = []
    is_branched = False
    unmerged_heads = []
    for idx, rev in enumerate(revisions):
        mod_path = Path(rev.module.__file__)
        if app == "airflow":
            version = rev.module.airflow_version.split(".")[0:3]  # only first 3 tokens
        else:
            version = rev.module.fab_version.split(".")[0:3]  # only first 3 tokens
        correct_mod_basename = ensure_mod_prefix(mod_path.name, idx, version)
        if mod_path.name != correct_mod_basename:
            renames.append((mod_path, Path(mod_path.parent, correct_mod_basename)))
        if is_branched and rev.is_merge_point:
            is_branched = False
        if rev.is_branch_point:
            is_branched = True
        elif rev.is_head:
            unmerged_heads.append(rev.revision)
    if is_branched:
        head_prefixes = [x[0:4] for x in unmerged_heads]
        alembic_command = (
            "alembic merge -m 'merge heads " + ", ".join(head_prefixes) + "' " + " ".join(unmerged_heads)
        )
        raise SystemExit(
            "You have multiple alembic heads; please merge them with by running `alembic merge` command under "
            f'"airflow" directory (where alembic.ini located) and re-run pre-commit. '
            f"It should fail once more before succeeding.\nhint: `{alembic_command}`"
        )
    for old, new in renames:
        os.rename(old, new)


def correct_mismatching_revision_nums(revisions: Iterable[Script]):
    revision_pattern = r'revision = ["\']([a-fA-F0-9]+)["\']'
    down_revision_pattern = r'down_revision = ["\']([a-fA-F0-9]+)["\']'
    revision_id_pattern = r"Revision ID: ([a-fA-F0-9]+)"
    revises_id_pattern = r"Revises: ([a-fA-F0-9]+)"
    for rev in revisions:
        if TYPE_CHECKING:  # For mypy
            assert rev.module.__file__ is not None
        file = Path(rev.module.__file__)
        content = file.read_text()
        revision_match = re.search(
            revision_pattern,
            content,
        )
        if revision_match is None:
            raise RuntimeError(f"revision = not found in {file}")
        revision_id_match = re.search(revision_id_pattern, content)
        if revision_id_match is None:
            raise RuntimeError(f"Revision ID: not found in {file}")
        new_content = content.replace(revision_id_match.group(1), revision_match.group(1), 1)
        down_revision_match = re.search(down_revision_pattern, new_content)
        revises_id_match = re.search(revises_id_pattern, new_content)
        if down_revision_match:
            if revises_id_match is None:
                raise RuntimeError(f"Revises: not found in {file}")
            new_content = new_content.replace(revises_id_match.group(1), down_revision_match.group(1), 1)
        file.write_text(new_content)


if __name__ == "__main__":
    apps = ["airflow", "fab"]
    for app in apps:
        console.print(f"[bright_blue]Updating migration reference for {app}")
        revisions = list(reversed(list(get_revisions(app))))
        console.print(f"[bright_blue]Making sure {app} version updated")
        ensure_version(revisions=revisions, app=app)
        console.print("[bright_blue]Making sure there's no mismatching revision numbers")
        correct_mismatching_revision_nums(revisions=revisions)
        revisions = list(reversed(list(get_revisions(app=app))))
        console.print("[bright_blue]Making sure filenames are sorted")
        ensure_filenames_are_sorted(revisions=revisions, app=app)
        revisions = list(get_revisions(app=app))
        console.print("[bright_blue]Updating documentation")
        update_docs(revisions=revisions, app=app)
        console.print("[green]Migrations OK")
