#!/usr/bin/env python

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to convert type hints to follow PEP-0585

For more information, see https://peps.python.org/pep-0585

To run from the repository's root directory:
    python convert-types.py
"""

from __future__ import annotations

from collections.abc import Callable, Iterator
import difflib
from glob import glob
import logging
import re
import sys
from typing import NamedTuple, TypeVar

# TODO:
# - False positives with lambdas and dict comprehensions due to the `:` misinterpreted as a type hint
# - Sort imports case insensitive (e.g. PIL, Flask)
# - Type hint arguments can be lists like `Callable[[a, b], c]`

# Cases not covered:
# - Multi-line imports like `from M import (\nA,\nB,\n)`
# - Importing `typing` directly like `import typing` and `x: typing.Any`
# - Parsing types with `|` syntax like Union or Optional
# - typing.re.Match --> re.Match
# - typing.re.Pattern --> re.Pattern


BUILTIN_TYPES = {"Tuple", "List", "Dict", "Set", "FrozenSet", "Type"}
COLLECTIONS_TYPES = {"Deque", "DefaultDict", "OrderedDict", "Counter", "ChainMap"}
COLLECTIONS_ABC_TYPES = {
    "Awaitable",
    "Coroutine",
    "AsyncIterable",
    "AsyncIterator",
    "AsyncGenerator",
    "Iterable",
    "Iterator",
    "Generator",
    "Reversible",
    "Container",
    "Collection",
    "Callable",
    "AbstractSet",
    "MutableSet",
    "Mapping",
    "MutableMapping",
    "Sequence",
    "MutableSequence",
    "ByteString",
    "MappingView",
    "KeysView",
    "ItemsView",
    "ValuesView",
}
CONTEXTLIB_TYPES = {"ContextManager", "AsyncContextManager"}
RE_TYPES = {"Match", "Pattern"}

RENAME_TYPES = {
    "Tuple": "tuple",  # builtin
    "List": "list",  # builtin
    "Dict": "dict",  # builtin
    "Set": "set",  # builtin
    "FrozenSet": "frozenset",  # builtin
    "Type": "type",  # builtin
    "Deque": "deque",  # collections
    "DefaultDict": "defaultdict",  # collections
    "AbstractSet": "Set",  # collections.abc
    "ContextManager": "AbstractContextManager",  # contextlib
    "AsyncContextManager": "AbstractAsyncContextManager",  # contextlib
}

# Parser a = String -> (a, String)
a = TypeVar("a")
Parser = Callable[[str], tuple[a, str]]


class TypeHint(NamedTuple):
    name: str
    args: list[TypeHint]

    def __repr__(self) -> str:
        match (self.name, self.args):
            case ("Optional", [x]):
                return f"{x} | None"
            case ("Union", args):
                return " | ".join(map(str, args))
            case (name, []):
                return name
            case (name, args):
                return f"{name}[{', '.join(map(str, args))}]"

    def patch(self, types: set[str]) -> TypeHint:
        if self.name in types:
            name = RENAME_TYPES.get(self.name, self.name)
        else:
            name = self.name
        return TypeHint(name, [arg.patch(types) for arg in self.args])


def patch_file(file_path: str, dry_run: bool = False, quiet: bool = False) -> None:
    with open(file_path) as f:
        before = f.read()
    try:
        lines = [line.rstrip() for line in before.splitlines()]
        if types := find_typing_imports(lines):
            lines = insert_import_annotations(lines)
            lines = [patched for line in lines for patched in patch_imports(line)]
            lines = sort_imports(lines)
            after = patch_type_hints("\n".join(lines), types) + "\n"
            if before == after:
                return

            if not dry_run:
                with open(file_path, "w") as f:
                    f.write(after)
                print(file_path)
            elif not quiet:
                print(f"| {file_path}")
                print(f"+--{'-' * len(file_path)}")
                diffs = difflib.context_diff(
                    before.splitlines(keepends=True),
                    after.splitlines(keepends=True),
                    fromfile="Before changes",
                    tofile="After changes",
                    n=1,
                )
                sys.stdout.writelines(diffs)
                print(f"+{'=' * 100}")
                print("| Press [ENTER] to continue to the next file")
                input()
    except Exception:
        logging.exception(f"Could not process file: {file_path}")


def insert_import_annotations(lines: list[str]) -> list[str]:
    new_import = "from __future__ import annotations"
    if new_import in lines:
        return lines

    match find_import(lines):
        case None:
            return lines
        case i:
            if lines[i].startswith("from __future__ import "):
                return lines[:i] + [new_import] + lines[i:]
            return lines[:i] + [new_import, ""] + lines[i:]


def find_typing_imports(lines: list[str]) -> set[str]:
    return {
        name.strip()
        for line in lines
        if line.startswith("from typing import ")
        for name in line.split("import")[1].split(",")
    }


def find_import(lines: list[str]) -> int | None:
    for i, line in enumerate(lines):
        if line.startswith(("import ", "from ")):
            return i
    return None


def get_imports_group(lines: list[str]) -> tuple[list[str], list[str]]:
    for i, line in enumerate(lines):
        if not line.strip() or line.startswith("#"):
            return (lines[:i], lines[i:])
    return ([], lines)


def import_name(line: str) -> str:
    match line.split():
        case ["import", name, *_]:
            return name
        case ["from", name, "import", *_]:
            return name
    raise ValueError(f"not an import: {line}")


def sort_imports(lines: list[str]) -> list[str]:
    match find_import(lines):
        case None:
            return lines
        case i:
            (imports, left) = get_imports_group(lines[i:])
            if imports:
                return lines[:i] + sorted(imports, key=import_name) + sort_imports(left)
            return left


def patch_imports(line: str) -> Iterator[str]:
    if not line.startswith("from typing import "):
        yield line
        return

    types = find_typing_imports([line])
    collections_types = types.intersection(COLLECTIONS_TYPES)
    collections_abc_types = types.intersection(COLLECTIONS_ABC_TYPES)
    contextlib_types = types.intersection(CONTEXTLIB_TYPES)
    re_types = types.intersection(RE_TYPES)
    typing_types = (
        types
        - BUILTIN_TYPES
        - COLLECTIONS_TYPES
        - COLLECTIONS_ABC_TYPES
        - CONTEXTLIB_TYPES
        - RE_TYPES
        - {"Optional", "Union"}
    )

    rename = lambda name: RENAME_TYPES.get(name, name)
    if collections_types:
        names = sorted(map(rename, collections_types))
        yield f"from collections import {', '.join(names)}"
    if collections_abc_types:
        names = sorted(map(rename, collections_abc_types))
        yield f"from collections.abc import {', '.join(names)}"
    if contextlib_types:
        names = sorted(map(rename, contextlib_types))
        yield f"from contextlib import {', '.join(names)}"
    if re_types:
        names = sorted(map(rename, re_types))
        yield f"from re import {', '.join(names)}"
    if typing_types:
        names = sorted(map(rename, typing_types))
        yield f"from typing import {', '.join(names)}"


def patch_type_hints(txt: str, types: set[str]) -> str:
    if m := re.search(rf"(?:->|:) *(\w+)", txt):
        (typ, left) = parse_type_hint(txt[m.start(1) :])
        return f"{txt[:m.start(1)]}{typ.patch(types)}{patch_type_hints(left, types)}"
    return txt


# Parser combinators
def parse_text(src: str, txt: str) -> tuple[str, str]:
    if src.startswith(txt):
        return (src[: len(txt)], src[len(txt) :])
    raise SyntaxError("text")


def parse_identifier(src: str) -> tuple[str, str]:
    if m := re.search(r"[\w\._]+", src):
        return (m.group(), src[m.end() :])
    raise SyntaxError("name")


def parse_zero_or_more(src: str, parser: Parser[a]) -> tuple[list[a], str]:
    try:
        (x, src) = parser(src)
        (xs, src) = parse_zero_or_more(src, parser)
        return ([x] + xs, src)
    except SyntaxError:
        return ([], src)


def parse_comma_separated(src: str, parser: Parser[a]) -> tuple[list[a], str]:
    def parse_next(src: str) -> tuple[a, str]:
        (_, src) = parse_text(src, ",")
        (_, src) = parse_zero_or_more(src, lambda src: parse_text(src, " "))
        return parser(src)

    try:
        (x, src) = parser(src)
        (xs, src) = parse_zero_or_more(src, parse_next)
        return ([x] + xs, src)
    except SyntaxError:
        return ([], src)


def parse_type_hint(src: str) -> tuple[TypeHint, str]:
    (name, src) = parse_identifier(src)
    try:
        (_, src) = parse_text(src, "[")
        (args, src) = parse_comma_separated(src, parse_type_hint)
        (_, src) = parse_text(src, "]")
        return (TypeHint(name, args), src)
    except SyntaxError:
        return (TypeHint(name, []), src)


def run(patterns: list[str], dry_run: bool = False, quiet: bool = False):
    for pattern in patterns:
        for filename in glob(pattern, recursive=True):
            patch_file(filename, dry_run, quiet)


if __name__ == "__main__":
    import argparse

    assert sys.version_info.major == 3, "Requires Python 3"
    assert sys.version_info.minor >= 10, "Requires Python >= 3.10 for pattern matching"

    parser = argparse.ArgumentParser()
    parser.add_argument("patterns", nargs="*", default=["**/*.py"])
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args()

    run(**args.__dict__)
